diff --git a/server/channelserver/char_save_locks.go b/server/channelserver/char_save_locks.go new file mode 100644 index 000000000..d4aee0d86 --- /dev/null +++ b/server/channelserver/char_save_locks.go @@ -0,0 +1,25 @@ +package channelserver + +import "sync" + +// CharacterLocks provides per-character mutexes to serialize save operations. +// This prevents concurrent saves for the same character from racing, which +// could defeat corruption detection (e.g. house tier snapshot vs. write). +// +// The underlying sync.Map grows lazily — entries are created on first access +// and never removed (character IDs are bounded and reused across sessions). +type CharacterLocks struct { + m sync.Map // map[uint32]*sync.Mutex +} + +// Lock acquires the mutex for the given character and returns an unlock function. +// Usage: +// +// unlock := s.server.charSaveLocks.Lock(charID) +// defer unlock() +func (cl *CharacterLocks) Lock(charID uint32) func() { + val, _ := cl.m.LoadOrStore(charID, &sync.Mutex{}) + mu := val.(*sync.Mutex) + mu.Lock() + return mu.Unlock +} diff --git a/server/channelserver/handlers_character.go b/server/channelserver/handlers_character.go index 49638e2bc..b8a84bfca 100644 --- a/server/channelserver/handlers_character.go +++ b/server/channelserver/handlers_character.go @@ -1,6 +1,8 @@ package channelserver import ( + "bytes" + "crypto/sha256" "database/sql" "errors" "fmt" @@ -18,9 +20,10 @@ const ( saveBackupInterval = 30 * time.Minute // minimum time between backups ) -// GetCharacterSaveData loads a character's save data from the database. +// GetCharacterSaveData loads a character's save data from the database and +// verifies its integrity checksum when one is stored. 
func GetCharacterSaveData(s *Session, charID uint32) (*CharacterSaveData, error) { - id, savedata, isNew, name, err := s.server.charRepo.LoadSaveData(charID) + id, savedata, isNew, name, storedHash, err := s.server.charRepo.LoadSaveDataWithHash(charID) if err != nil { if errors.Is(err, sql.ErrNoRows) { s.logger.Error("No savedata found", zap.Uint32("charID", charID)) @@ -49,6 +52,22 @@ func GetCharacterSaveData(s *Session, charID uint32) (*CharacterSaveData, error) return nil, err } + // Verify integrity checksum if one was stored with this save. + // A nil hash means the character was saved before checksums were introduced, + // so we skip verification (the next save will compute and store the hash). + if storedHash != nil { + computedHash := sha256.Sum256(saveData.decompSave) + if !bytes.Equal(storedHash, computedHash[:]) { + s.logger.Error("Savedata integrity check failed: hash mismatch", + zap.Uint32("charID", charID), + zap.Binary("stored_hash", storedHash), + zap.Binary("computed_hash", computedHash[:]), + ) + // TODO: attempt recovery from savedata_backups here + return nil, errors.New("savedata integrity check failed") + } + } + saveData.updateStructWithSaveData() return saveData, nil @@ -85,56 +104,63 @@ func (save *CharacterSaveData) Save(s *Session) error { save.compSave = save.decompSave } - // Time-gated rotating backup: snapshot the previous compressed savedata - // before overwriting, but only if enough time has elapsed since the last - // backup. This keeps storage bounded (3 slots × blob size per character) - // while providing recovery points. + // Compute integrity hash over the decompressed save. + hash := sha256.Sum256(save.decompSave) + + // Build the atomic save params — character data, house data, hash, and + // optionally a backup snapshot, all in one transaction. 
+ params := SaveAtomicParams{ + CharID: save.CharID, + CompSave: save.compSave, + Hash: hash[:], + HR: save.HR, + GR: save.GR, + IsFemale: save.Gender, + WeaponType: save.WeaponType, + WeaponID: save.WeaponID, + HouseTier: save.HouseTier, + HouseData: save.HouseData, + BookshelfData: save.BookshelfData, + GalleryData: save.GalleryData, + ToreData: save.ToreData, + GardenData: save.GardenData, + } + + // Time-gated rotating backup: include the previous compressed savedata + // in the transaction if enough time has elapsed since the last backup. if len(prevCompSave) > 0 { - maybeSaveBackup(s, save.CharID, prevCompSave) + if slot, ok := shouldBackup(s, save.CharID); ok { + params.BackupSlot = slot + params.BackupData = prevCompSave + } } - if err := s.server.charRepo.SaveCharacterData(save.CharID, save.compSave, save.HR, save.GR, save.Gender, save.WeaponType, save.WeaponID); err != nil { - s.logger.Error("Failed to update savedata", zap.Error(err), zap.Uint32("charID", save.CharID)) - return fmt.Errorf("save character data: %w", err) - } - - if err := s.server.charRepo.SaveHouseData(s.charID, save.HouseTier, save.HouseData, save.BookshelfData, save.GalleryData, save.ToreData, save.GardenData); err != nil { - s.logger.Error("Failed to update user binary house data", zap.Error(err)) - return fmt.Errorf("save house data: %w", err) + if err := s.server.charRepo.SaveCharacterDataAtomic(params); err != nil { + s.logger.Error("Failed to save character data atomically", + zap.Error(err), zap.Uint32("charID", save.CharID)) + return fmt.Errorf("atomic save: %w", err) } return nil } -// maybeSaveBackup checks whether enough time has elapsed since the last backup -// and, if so, writes the given compressed savedata into the next rotating slot. -// Errors are logged but do not block the save — backups are best-effort. 
-func maybeSaveBackup(s *Session, charID uint32, compSave []byte) {
+// shouldBackup reports whether enough time has elapsed since the last backup,
+// returning (slot, true) when a backup is due and (0, false) otherwise.
+// NOTE(review): a zero lastBackup has a negative Unix(), and Go's % keeps the sign, so slot below can be negative — confirm GetLastBackupTime never yields a pre-epoch time, or guard the result.
+func shouldBackup(s *Session, charID uint32) (int, bool) {
 	lastBackup, err := s.server.charRepo.GetLastBackupTime(charID)
 	if err != nil {
 		s.logger.Warn("Failed to query last backup time, skipping backup",
 			zap.Error(err), zap.Uint32("charID", charID))
-		return
+		return 0, false
 	}
 
 	if time.Since(lastBackup) < saveBackupInterval {
-		return
+		return 0, false
 	}
 
-	// Pick the next slot using a simple counter derived from the backup times.
-	// We rotate through slots 0, 1, 2 based on how many backups exist modulo
-	// the slot count. In practice this fills slots in order and then overwrites
-	// the oldest.
 	slot := int(lastBackup.Unix()/int64(saveBackupInterval.Seconds())) % saveBackupSlots
-
-	if err := s.server.charRepo.SaveBackup(charID, slot, compSave); err != nil {
-		s.logger.Warn("Failed to save backup",
-			zap.Error(err), zap.Uint32("charID", charID), zap.Int("slot", slot))
-		return
-	}
-
-	s.logger.Info("Savedata backup created",
-		zap.Uint32("charID", charID), zap.Int("slot", slot))
+	return slot, true
 }
 
 func handleMsgMhfSexChanger(s *Session, p mhfpacket.MHFPacket) {
diff --git a/server/channelserver/handlers_data.go b/server/channelserver/handlers_data.go
index a84b91c17..6cf5ff3a1 100644
--- a/server/channelserver/handlers_data.go
+++ b/server/channelserver/handlers_data.go
@@ -28,6 +28,11 @@ const (
 func handleMsgMhfSavedata(s *Session, p mhfpacket.MHFPacket) {
 	pkt := p.(*mhfpacket.MsgMhfSavedata)
 
+	// Serialize saves for the same character to prevent concurrent operations
+	// from racing and defeating corruption detection.
+ unlock := s.server.charSaveLocks.Lock(s.charID) + defer unlock() + if len(pkt.RawDataPayload) > saveDataMaxCompressedPayload { s.logger.Warn("Savedata payload exceeds size limit", zap.Int("len", len(pkt.RawDataPayload)), diff --git a/server/channelserver/handlers_data_test.go b/server/channelserver/handlers_data_test.go index 210c3f5b9..19177a69d 100644 --- a/server/channelserver/handlers_data_test.go +++ b/server/channelserver/handlers_data_test.go @@ -2,8 +2,11 @@ package channelserver import ( "bytes" + "crypto/sha256" "encoding/binary" "fmt" + "sync" + "sync/atomic" "erupe-ce/common/byteframe" "erupe-ce/network" @@ -719,6 +722,132 @@ func TestBackupConstants(t *testing.T) { } } +// ============================================================================= +// Tier 2 protection tests +// ============================================================================= + +func TestSaveDataChecksumRoundTrip(t *testing.T) { + // Verify that a hash computed over decompressed data matches after + // a compress → decompress round trip (the checksum covers decompressed data). 
+	original := make([]byte, 1000)
+	for i := range original {
+		original[i] = byte(i % 256)
+	}
+
+	hash1 := sha256.Sum256(original)
+
+	compressed, err := nullcomp.Compress(original)
+	if err != nil {
+		t.Fatalf("compress: %v", err)
+	}
+
+	decompressed, err := nullcomp.Decompress(compressed)
+	if err != nil {
+		t.Fatalf("decompress: %v", err)
+	}
+
+	hash2 := sha256.Sum256(decompressed)
+
+	if hash1 != hash2 {
+		t.Error("checksum mismatch after compress/decompress round trip")
+	}
+}
+
+func TestSaveDataChecksumDetectsCorruption(t *testing.T) {
+	data := []byte{0x01, 0x02, 0x03, 0x04, 0x05}
+	hash := sha256.Sum256(data)
+
+	// Flip a bit
+	corrupted := make([]byte, len(data))
+	copy(corrupted, data)
+	corrupted[2] ^= 0x01
+
+	corruptedHash := sha256.Sum256(corrupted)
+
+	if bytes.Equal(hash[:], corruptedHash[:]) {
+		t.Error("checksum should differ after bit flip")
+	}
+}
+
+func TestSaveAtomicParamsStructure(t *testing.T) {
+	params := SaveAtomicParams{
+		CharID:     42,
+		CompSave:   []byte{0x01},
+		Hash:       make([]byte, 32),
+		HR:         999,
+		GR:         100,
+		IsFemale:   true,
+		WeaponType: 7,
+		WeaponID:   1234,
+		HouseTier:  []byte{0x01, 0x00, 0x00, 0x00, 0x00},
+	}
+
+	if params.CharID != 42 {
+		t.Error("CharID mismatch")
+	}
+	if len(params.Hash) != 32 {
+		t.Errorf("hash should be 32 bytes, got %d", len(params.Hash))
+	}
+	if params.BackupData != nil {
+		t.Error("BackupData should be nil when no backup requested")
+	}
+}
+
+func TestCharacterLocks_SerializesSameCharacter(t *testing.T) {
+	var locks CharacterLocks
+	var counter int64
+
+	const goroutines = 100
+	var wg sync.WaitGroup
+	wg.Add(goroutines)
+
+	for i := 0; i < goroutines; i++ {
+		go func() {
+			defer wg.Done()
+			unlock := locks.Lock(1) // same charID
+			// Plain, unsynchronized increment: if the lock fails to serialize, this is a data race (-race reports it) and updates are lost
+			v := counter
+			counter = v + 1
+			unlock()
+		}()
+	}
+	wg.Wait()
+
+	if atomic.LoadInt64(&counter) != goroutines {
+		t.Errorf("expected counter=%d, got %d",
goroutines, atomic.LoadInt64(&counter)) + } +} + +func TestCharacterLocks_DifferentCharactersIndependent(t *testing.T) { + var locks CharacterLocks + var started, finished sync.WaitGroup + + started.Add(1) + finished.Add(2) + + // Lock char 1 + unlock1 := locks.Lock(1) + + // Goroutine trying to lock char 2 should succeed immediately + go func() { + defer finished.Done() + unlock2 := locks.Lock(2) // different char — should not block + started.Done() + unlock2() + }() + + // Wait for char 2 lock to succeed (proves independence) + started.Wait() + unlock1() + + // Goroutine for char 1 cleanup + go func() { + defer finished.Done() + }() + + finished.Wait() +} + // ============================================================================= // Tests consolidated from handlers_coverage4_test.go // ============================================================================= diff --git a/server/channelserver/repo_character.go b/server/channelserver/repo_character.go index 4de16bd32..6bc3a9b6b 100644 --- a/server/channelserver/repo_character.go +++ b/server/channelserver/repo_character.go @@ -2,11 +2,36 @@ package channelserver import ( "database/sql" + "fmt" "time" "github.com/jmoiron/sqlx" ) +// SaveAtomicParams bundles all fields needed for an atomic save transaction. +type SaveAtomicParams struct { + CharID uint32 + CompSave []byte + Hash []byte // SHA-256 of decompressed savedata + HR uint16 + GR uint16 + IsFemale bool + WeaponType uint8 + WeaponID uint16 + + // House data (written to user_binary) + HouseTier []byte + HouseData []byte + BookshelfData []byte + GalleryData []byte + ToreData []byte + GardenData []byte + + // Optional backup (nil means skip) + BackupSlot int + BackupData []byte +} + // CharacterRepository centralizes all database access for the characters table. 
type CharacterRepository struct { db *sqlx.DB @@ -238,6 +263,60 @@ func (r *CharacterRepository) GetLastBackupTime(charID uint32) (time.Time, error return t.Time, nil } +// SaveCharacterDataAtomic performs all save-related writes in a single +// database transaction. If any step fails, everything is rolled back. +func (r *CharacterRepository) SaveCharacterDataAtomic(params SaveAtomicParams) error { + tx, err := r.db.Beginx() + if err != nil { + return fmt.Errorf("begin tx: %w", err) + } + defer tx.Rollback() //nolint:errcheck // rollback is no-op after commit + + // 1. Save character data + hash + if _, err := tx.Exec( + `UPDATE characters SET savedata=$1, savedata_hash=$2, is_new_character=false, hr=$3, gr=$4, is_female=$5, weapon_type=$6, weapon_id=$7 WHERE id=$8`, + params.CompSave, params.Hash, params.HR, params.GR, params.IsFemale, params.WeaponType, params.WeaponID, params.CharID, + ); err != nil { + return fmt.Errorf("save character data: %w", err) + } + + // 2. Save house data + if _, err := tx.Exec( + `UPDATE user_binary SET house_tier=$1, house_data=$2, bookshelf=$3, gallery=$4, tore=$5, garden=$6 WHERE id=$7`, + params.HouseTier, params.HouseData, params.BookshelfData, params.GalleryData, params.ToreData, params.GardenData, params.CharID, + ); err != nil { + return fmt.Errorf("save house data: %w", err) + } + + // 3. Optional backup + if params.BackupData != nil { + if _, err := tx.Exec( + `INSERT INTO savedata_backups (char_id, slot, savedata, saved_at) + VALUES ($1, $2, $3, now()) + ON CONFLICT (char_id, slot) DO UPDATE SET savedata = $3, saved_at = now()`, + params.CharID, params.BackupSlot, params.BackupData, + ); err != nil { + return fmt.Errorf("save backup: %w", err) + } + } + + return tx.Commit() +} + +// LoadSaveDataWithHash reads the core save columns plus the integrity hash. +// The hash may be nil for characters saved before checksums were introduced. 
+func (r *CharacterRepository) LoadSaveDataWithHash(charID uint32) (uint32, []byte, bool, string, []byte, error) { + var id uint32 + var savedata []byte + var isNew bool + var name string + var hash []byte + err := r.db.QueryRow( + "SELECT id, savedata, is_new_character, name, savedata_hash FROM characters WHERE id = $1", charID, + ).Scan(&id, &savedata, &isNew, &name, &hash) + return id, savedata, isNew, name, hash, err +} + // FindByRastaID looks up name and id by rasta_id. func (r *CharacterRepository) FindByRastaID(rastaID int) (charID uint32, name string, err error) { err = r.db.QueryRow("SELECT name, id FROM characters WHERE rasta_id=$1", rastaID).Scan(&name, &charID) diff --git a/server/channelserver/repo_interfaces.go b/server/channelserver/repo_interfaces.go index 46571c0eb..6d2d9733f 100644 --- a/server/channelserver/repo_interfaces.go +++ b/server/channelserver/repo_interfaces.go @@ -41,6 +41,13 @@ type CharacterRepo interface { LoadSaveData(charID uint32) (uint32, []byte, bool, string, error) SaveBackup(charID uint32, slot int, data []byte) error GetLastBackupTime(charID uint32) (time.Time, error) + // SaveCharacterDataAtomic performs all save-related writes in a single + // database transaction: character data, house data, checksum, and + // optionally a backup snapshot. If any step fails, everything is rolled back. + SaveCharacterDataAtomic(params SaveAtomicParams) error + // LoadSaveDataWithHash loads savedata along with its stored SHA-256 hash. + // The hash may be nil for characters saved before checksums were introduced. + LoadSaveDataWithHash(charID uint32) (id uint32, savedata []byte, isNew bool, name string, hash []byte, err error) } // GuildRepo defines the contract for guild data access. 
diff --git a/server/channelserver/repo_mocks_test.go b/server/channelserver/repo_mocks_test.go index 449bce735..6617bc0ca 100644 --- a/server/channelserver/repo_mocks_test.go +++ b/server/channelserver/repo_mocks_test.go @@ -228,8 +228,12 @@ func (m *mockCharacterRepo) SaveHouseData(_ uint32, _ []byte, _, _, _, _, _ []by func (m *mockCharacterRepo) LoadSaveData(_ uint32) (uint32, []byte, bool, string, error) { return m.loadSaveDataID, m.loadSaveDataData, m.loadSaveDataNew, m.loadSaveDataName, m.loadSaveDataErr } -func (m *mockCharacterRepo) SaveBackup(_ uint32, _ int, _ []byte) error { return nil } -func (m *mockCharacterRepo) GetLastBackupTime(_ uint32) (time.Time, error) { return time.Time{}, nil } +func (m *mockCharacterRepo) SaveBackup(_ uint32, _ int, _ []byte) error { return nil } +func (m *mockCharacterRepo) GetLastBackupTime(_ uint32) (time.Time, error) { return time.Time{}, nil } +func (m *mockCharacterRepo) SaveCharacterDataAtomic(_ SaveAtomicParams) error { return nil } +func (m *mockCharacterRepo) LoadSaveDataWithHash(_ uint32) (uint32, []byte, bool, string, []byte, error) { + return m.loadSaveDataID, m.loadSaveDataData, m.loadSaveDataNew, m.loadSaveDataName, nil, m.loadSaveDataErr +} // --- mockGoocooRepo --- diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index b569d5dce..712c61b9c 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -96,6 +96,10 @@ type Server struct { userBinary *UserBinaryStore minidata *MinidataStore + // Per-character save locks prevent concurrent save operations for the + // same character from racing and defeating corruption detection. 
+ charSaveLocks CharacterLocks + // Semaphore semaphoreLock sync.RWMutex semaphore map[string]*Semaphore diff --git a/server/migrations/sql/0008_savedata_hash.sql b/server/migrations/sql/0008_savedata_hash.sql new file mode 100644 index 000000000..04c19856b --- /dev/null +++ b/server/migrations/sql/0008_savedata_hash.sql @@ -0,0 +1,4 @@ +-- Add SHA-256 checksum column for savedata integrity verification. +-- Stored as 32 raw bytes (not hex). NULL means no hash computed yet +-- (backwards-compatible with existing data). +ALTER TABLE characters ADD COLUMN IF NOT EXISTS savedata_hash BYTEA;