NOTE(review): line structure reconstructed from the hunk headers after this
patch had been collapsed onto a handful of physical lines. Context-line
indentation and the position of one blank context line in the
invalidateSessions hunk are best-effort; verify against the target tree before
applying. The new test file carries review fixes (logger error check, closed
stray connection, checked Write), so it no longer matches blob 158fca9a3.

diff --git a/server/channelserver/channel_isolation_test.go b/server/channelserver/channel_isolation_test.go
new file mode 100644
index 000000000..158fca9a3
--- /dev/null
+++ b/server/channelserver/channel_isolation_test.go
@@ -0,0 +1,220 @@
+package channelserver
+
+import (
+	"net"
+	"testing"
+	"time"
+
+	_config "erupe-ce/config"
+
+	"go.uber.org/zap"
+)
+
+// createListeningTestServer creates a channel server that binds to a real TCP port.
+// Port 0 lets the OS assign a free port. The server is automatically shut down
+// when the test completes.
+func createListeningTestServer(t *testing.T, id uint16) *Server {
+	t.Helper()
+	logger, err := zap.NewDevelopment()
+	if err != nil {
+		t.Fatalf("channel %d failed to create logger: %v", id, err)
+	}
+	s := NewServer(&Config{
+		ID:     id,
+		Logger: logger,
+		ErupeConfig: &_config.Config{
+			DebugOptions: _config.DebugOptions{
+				LogOutboundMessages: false,
+				LogInboundMessages:  false,
+			},
+		},
+	})
+	s.Port = 0 // Let OS pick a free port
+	if err := s.Start(); err != nil {
+		t.Fatalf("channel %d failed to start: %v", id, err)
+	}
+	t.Cleanup(func() {
+		s.Shutdown()
+		time.Sleep(200 * time.Millisecond) // Let background goroutines and sessions exit.
+	})
+	return s
+}
+
+// listenerAddr returns the address the server is listening on.
+func listenerAddr(s *Server) string {
+	return s.listener.Addr().String()
+}
+
+// TestChannelIsolation_ShutdownDoesNotAffectOthers verifies that shutting down
+// one channel server does not prevent other channels from accepting connections.
+func TestChannelIsolation_ShutdownDoesNotAffectOthers(t *testing.T) {
+	ch1 := createListeningTestServer(t, 1)
+	ch2 := createListeningTestServer(t, 2)
+	ch3 := createListeningTestServer(t, 3)
+
+	addr1 := listenerAddr(ch1)
+	addr2 := listenerAddr(ch2)
+	addr3 := listenerAddr(ch3)
+
+	// Verify all three channels accept connections initially.
+	for _, addr := range []string{addr1, addr2, addr3} {
+		conn, err := net.DialTimeout("tcp", addr, time.Second)
+		if err != nil {
+			t.Fatalf("initial connection to %s failed: %v", addr, err)
+		}
+		conn.Close()
+	}
+
+	// Shut down channel 1.
+	ch1.Shutdown()
+	time.Sleep(50 * time.Millisecond)
+
+	// Channel 1 should refuse connections.
+	if conn, err := net.DialTimeout("tcp", addr1, 500*time.Millisecond); err == nil {
+		conn.Close() // Do not leak the connection that should not have been accepted.
+		t.Error("channel 1 should refuse connections after shutdown")
+	}
+
+	// Channels 2 and 3 must still accept connections.
+	for _, tc := range []struct {
+		name string
+		addr string
+	}{
+		{"channel 2", addr2},
+		{"channel 3", addr3},
+	} {
+		conn, err := net.DialTimeout("tcp", tc.addr, time.Second)
+		if err != nil {
+			t.Errorf("%s should still accept connections after channel 1 shutdown, got: %v", tc.name, err)
+		} else {
+			conn.Close()
+		}
+	}
+}
+
+// TestChannelIsolation_ListenerCloseDoesNotAffectOthers simulates an unexpected
+// listener failure (e.g. port conflict, OS-level error) on one channel and
+// verifies other channels continue operating.
+func TestChannelIsolation_ListenerCloseDoesNotAffectOthers(t *testing.T) {
+	ch1 := createListeningTestServer(t, 1)
+	ch2 := createListeningTestServer(t, 2)
+
+	addr2 := listenerAddr(ch2)
+
+	// Forcibly close channel 1's listener (simulating unexpected failure).
+	ch1.listener.Close()
+	time.Sleep(50 * time.Millisecond)
+
+	// Channel 2 must still work.
+	conn, err := net.DialTimeout("tcp", addr2, time.Second)
+	if err != nil {
+		t.Fatalf("channel 2 should still accept connections after channel 1 listener closed: %v", err)
+	}
+	conn.Close()
+}
+
+// TestChannelIsolation_SessionPanicDoesNotAffectChannel verifies that a panic
+// inside a session handler is recovered and does not crash the channel server.
+func TestChannelIsolation_SessionPanicDoesNotAffectChannel(t *testing.T) {
+	ch := createListeningTestServer(t, 1)
+	addr := listenerAddr(ch)
+
+	// Connect a client that will trigger a session.
+	conn1, err := net.DialTimeout("tcp", addr, time.Second)
+	if err != nil {
+		t.Fatalf("first connection failed: %v", err)
+	}
+
+	// Send garbage data that will cause handlePacketGroup to hit the panic recovery.
+	// The session's defer/recover should catch it without killing the channel.
+	if _, err := conn1.Write([]byte{0xFF, 0xFF, 0xFF, 0xFF}); err != nil {
+		conn1.Close()
+		t.Fatalf("sending garbage payload failed: %v", err)
+	}
+	time.Sleep(100 * time.Millisecond)
+	conn1.Close()
+	time.Sleep(100 * time.Millisecond)
+
+	// The channel should still accept new connections after the panic.
+	conn2, err := net.DialTimeout("tcp", addr, time.Second)
+	if err != nil {
+		t.Fatalf("channel should still accept connections after session panic: %v", err)
+	}
+	conn2.Close()
+}
+
+// TestChannelIsolation_CrossChannelRegistryAfterShutdown verifies that the
+// channel registry handles a shut-down channel gracefully during cross-channel
+// operations (search, find, disconnect).
+func TestChannelIsolation_CrossChannelRegistryAfterShutdown(t *testing.T) {
+	channels := createTestChannels(3)
+	reg := NewLocalChannelRegistry(channels)
+
+	// Add sessions to all channels.
+	for i, ch := range channels {
+		conn := &mockConn{}
+		sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player")
+		sess.stage = NewStage("sl1Ns200p0a0u0")
+		ch.Lock()
+		ch.sessions[conn] = sess
+		ch.Unlock()
+	}
+
+	// Simulate channel 1 shutting down by marking it and clearing sessions.
+	channels[0].Lock()
+	channels[0].isShuttingDown = true
+	channels[0].sessions = make(map[net.Conn]*Session)
+	channels[0].Unlock()
+
+	// Registry operations should still work for remaining channels.
+	found := reg.FindSessionByCharID(2)
+	if found == nil {
+		t.Error("FindSessionByCharID(2) should find session on channel 2")
+	}
+
+	found = reg.FindSessionByCharID(3)
+	if found == nil {
+		t.Error("FindSessionByCharID(3) should find session on channel 3")
+	}
+
+	// Session from shut-down channel should not be found.
+	found = reg.FindSessionByCharID(1)
+	if found != nil {
+		t.Error("FindSessionByCharID(1) should not find session on shut-down channel")
+	}
+
+	// SearchSessions should return only sessions from live channels.
+	results := reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 10)
+	if len(results) != 2 {
+		t.Errorf("SearchSessions should return 2 results from live channels, got %d", len(results))
+	}
+}
+
+// TestChannelIsolation_IndependentStages verifies that stages are per-channel
+// and one channel's stages don't leak into another.
+func TestChannelIsolation_IndependentStages(t *testing.T) {
+	channels := createTestChannels(2)
+
+	stageName := "sl1Qs999p0a0u42"
+
+	// Add stage only to channel 1.
+	channels[0].stagesLock.Lock()
+	channels[0].stages[stageName] = NewStage(stageName)
+	channels[0].stagesLock.Unlock()
+
+	// Channel 1 should have the stage.
+	channels[0].stagesLock.RLock()
+	_, ok1 := channels[0].stages[stageName]
+	channels[0].stagesLock.RUnlock()
+	if !ok1 {
+		t.Error("channel 1 should have the stage")
+	}
+
+	// Channel 2 should NOT have the stage.
+	channels[1].stagesLock.RLock()
+	_, ok2 := channels[1].stages[stageName]
+	channels[1].stagesLock.RUnlock()
+	if ok2 {
+		t.Error("channel 2 should not have channel 1's stage")
+	}
+}
diff --git a/server/channelserver/handlers_session.go b/server/channelserver/handlers_session.go
index 1944a7902..ab39c6b95 100644
--- a/server/channelserver/handlers_session.go
+++ b/server/channelserver/handlers_session.go
@@ -293,14 +293,16 @@ func logoutPlayer(s *Session) {
 	}
 
 	// Update sign sessions and server player count
-	_, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token)
-	if err != nil {
-		panic(err)
-	}
+	if s.server.db != nil {
+		_, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token)
+		if err != nil {
+			s.logger.Error("Failed to clear sign session", zap.Error(err))
+		}
 
-	_, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID)
-	if err != nil {
-		panic(err)
+		_, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID)
+		if err != nil {
+			s.logger.Error("Failed to update player count", zap.Error(err))
+		}
 	}
 
 	if s.stage == nil {
diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go
index f042c6109..bfd22414d 100644
--- a/server/channelserver/sys_channel_server.go
+++ b/server/channelserver/sys_channel_server.go
@@ -50,6 +50,7 @@ type Server struct {
 	sessions       map[net.Conn]*Session
 	listener       net.Listener // Listener that is created when Server.Start is called.
 	isShuttingDown bool
+	done           chan struct{} // Closed on Shutdown to wake background goroutines.
 
 	stagesLock sync.RWMutex
 	stages     map[string]*Stage
@@ -91,6 +92,7 @@ func NewServer(config *Config) *Server {
 		erupeConfig:     config.ErupeConfig,
 		acceptConns:     make(chan net.Conn),
 		deleteConns:     make(chan net.Conn),
+		done:            make(chan struct{}),
 		sessions:        make(map[net.Conn]*Session),
 		stages:          make(map[string]*Stage),
 		userBinaryParts: make(map[userBinaryPartID][]byte),
@@ -156,19 +158,23 @@ func (s *Server) Start() error {
 	return nil
 }
 
-// Shutdown tries to shut down the server gracefully.
+// Shutdown tries to shut down the server gracefully. Safe to call multiple times.
 func (s *Server) Shutdown() {
 	s.Lock()
+	alreadyShutDown := s.isShuttingDown
 	s.isShuttingDown = true
 	s.Unlock()
 
+	if alreadyShutDown {
+		return
+	}
+
+	close(s.done)
+
 	if s.listener != nil {
 		_ = s.listener.Close()
 	}
 
-	if s.acceptConns != nil {
-		close(s.acceptConns)
-	}
 }
 
 func (s *Server) acceptClients() {
@@ -186,25 +192,21 @@ func (s *Server) acceptClients() {
 				continue
 			}
 		}
-		s.acceptConns <- conn
+		select {
+		case s.acceptConns <- conn:
+		case <-s.done:
+			_ = conn.Close()
+			return
+		}
 	}
 }
 
 func (s *Server) manageSessions() {
 	for {
 		select {
+		case <-s.done:
+			return
 		case newConn := <-s.acceptConns:
-			// Gracefully handle acceptConns channel closing.
-			if newConn == nil {
-				s.Lock()
-				shutdown := s.isShuttingDown
-				s.Unlock()
-
-				if shutdown {
-					return
-				}
-			}
-
 			session := NewSession(s, newConn)
 
 			s.Lock()
@@ -236,15 +238,28 @@ func (s *Server) getObjectId() uint16 {
 }
 
 func (s *Server) invalidateSessions() {
-	for !s.isShuttingDown {
+	ticker := time.NewTicker(10 * time.Second)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-s.done:
+			return
+		case <-ticker.C:
+		}
 
+		s.Lock()
+		var timedOut []*Session
 		for _, sess := range s.sessions {
 			if time.Since(sess.lastPacket) > time.Second*time.Duration(30) {
-				s.logger.Info("session timeout", zap.String("Name", sess.Name))
-				logoutPlayer(sess)
+				timedOut = append(timedOut, sess)
 			}
 		}
-		time.Sleep(time.Second * 10)
+		s.Unlock()
+
+		for _, sess := range timedOut {
+			s.logger.Info("session timeout", zap.String("Name", sess.Name))
+			logoutPlayer(sess)
+		}
 	}
 }
 
diff --git a/server/channelserver/sys_session.go b/server/channelserver/sys_session.go
index 294d470ab..b30190aec 100644
--- a/server/channelserver/sys_session.go
+++ b/server/channelserver/sys_session.go
@@ -4,7 +4,6 @@ import (
 	"encoding/binary"
 	"encoding/hex"
 	"erupe-ce/common/mhfcourse"
-	_config "erupe-ce/config"
 	"fmt"
 	"io"
 	"net"
@@ -172,7 +171,7 @@ func (s *Session) sendLoop() {
 				s.logger.Warn("Failed to send packet", zap.Error(err))
 			}
 		}
-		time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond)
+		time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond)
 	}
 }
 
@@ -215,7 +214,7 @@ func (s *Session) recvLoop() {
 			return
 		}
 		s.handlePacketGroup(pkt)
-		time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond)
+		time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond)
 	}
 }
 