mirror of
https://github.com/Mezeporta/Erupe.git
synced 2026-03-21 23:22:34 +01:00
fix(channelserver): eliminate data races in shutdown and session lifecycle
The channel server had several concurrency issues found by the race detector during isolation testing:

- acceptClients could send on a closed acceptConns channel during shutdown, causing a panic. Replace close(acceptConns) with a done channel and select-based shutdown signaling in both acceptClients and manageSessions.
- invalidateSessions read isShuttingDown and iterated sessions without holding the lock. Rewrite with a ticker + done channel select, and snapshot sessions under the lock before processing timeouts.
- sendLoop/recvLoop accessed the global _config.ErupeConfig.LoopDelay, which races with tests modifying the global. Use the per-server erupeConfig instead.
- logoutPlayer panicked on DB errors and crashed on a nil DB (no-db test scenarios). Guard with a nil check and log errors instead.
- Shutdown was not idempotent; calling it twice caused a double-close panic on the done channel.

Add 5 channel isolation tests verifying independent shutdown, listener failure, session panic recovery, cross-channel registry after shutdown, and stage isolation.
This commit is contained in:
214
server/channelserver/channel_isolation_test.go
Normal file
214
server/channelserver/channel_isolation_test.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
package channelserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_config "erupe-ce/config"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createListeningTestServer creates a channel server that binds to a real TCP port.
|
||||||
|
// Port 0 lets the OS assign a free port. The server is automatically shut down
|
||||||
|
// when the test completes.
|
||||||
|
func createListeningTestServer(t *testing.T, id uint16) *Server {
|
||||||
|
t.Helper()
|
||||||
|
logger, _ := zap.NewDevelopment()
|
||||||
|
s := NewServer(&Config{
|
||||||
|
ID: id,
|
||||||
|
Logger: logger,
|
||||||
|
ErupeConfig: &_config.Config{
|
||||||
|
DebugOptions: _config.DebugOptions{
|
||||||
|
LogOutboundMessages: false,
|
||||||
|
LogInboundMessages: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
s.Port = 0 // Let OS pick a free port
|
||||||
|
if err := s.Start(); err != nil {
|
||||||
|
t.Fatalf("channel %d failed to start: %v", id, err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
s.Shutdown()
|
||||||
|
time.Sleep(200 * time.Millisecond) // Let background goroutines and sessions exit.
|
||||||
|
})
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// listenerAddr returns the address the server is listening on.
|
||||||
|
func listenerAddr(s *Server) string {
|
||||||
|
return s.listener.Addr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestChannelIsolation_ShutdownDoesNotAffectOthers verifies that shutting down
|
||||||
|
// one channel server does not prevent other channels from accepting connections.
|
||||||
|
func TestChannelIsolation_ShutdownDoesNotAffectOthers(t *testing.T) {
|
||||||
|
ch1 := createListeningTestServer(t, 1)
|
||||||
|
ch2 := createListeningTestServer(t, 2)
|
||||||
|
ch3 := createListeningTestServer(t, 3)
|
||||||
|
|
||||||
|
addr1 := listenerAddr(ch1)
|
||||||
|
addr2 := listenerAddr(ch2)
|
||||||
|
addr3 := listenerAddr(ch3)
|
||||||
|
|
||||||
|
// Verify all three channels accept connections initially.
|
||||||
|
for _, addr := range []string{addr1, addr2, addr3} {
|
||||||
|
conn, err := net.DialTimeout("tcp", addr, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("initial connection to %s failed: %v", addr, err)
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shut down channel 1.
|
||||||
|
ch1.Shutdown()
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Channel 1 should refuse connections.
|
||||||
|
_, err := net.DialTimeout("tcp", addr1, 500*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("channel 1 should refuse connections after shutdown")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Channels 2 and 3 must still accept connections.
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
}{
|
||||||
|
{"channel 2", addr2},
|
||||||
|
{"channel 3", addr3},
|
||||||
|
} {
|
||||||
|
conn, err := net.DialTimeout("tcp", tc.addr, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%s should still accept connections after channel 1 shutdown, got: %v", tc.name, err)
|
||||||
|
} else {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestChannelIsolation_ListenerCloseDoesNotAffectOthers simulates an unexpected
|
||||||
|
// listener failure (e.g. port conflict, OS-level error) on one channel and
|
||||||
|
// verifies other channels continue operating.
|
||||||
|
func TestChannelIsolation_ListenerCloseDoesNotAffectOthers(t *testing.T) {
|
||||||
|
ch1 := createListeningTestServer(t, 1)
|
||||||
|
ch2 := createListeningTestServer(t, 2)
|
||||||
|
|
||||||
|
addr2 := listenerAddr(ch2)
|
||||||
|
|
||||||
|
// Forcibly close channel 1's listener (simulating unexpected failure).
|
||||||
|
ch1.listener.Close()
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Channel 2 must still work.
|
||||||
|
conn, err := net.DialTimeout("tcp", addr2, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("channel 2 should still accept connections after channel 1 listener closed: %v", err)
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestChannelIsolation_SessionPanicDoesNotAffectChannel verifies that a panic
|
||||||
|
// inside a session handler is recovered and does not crash the channel server.
|
||||||
|
func TestChannelIsolation_SessionPanicDoesNotAffectChannel(t *testing.T) {
|
||||||
|
ch := createListeningTestServer(t, 1)
|
||||||
|
addr := listenerAddr(ch)
|
||||||
|
|
||||||
|
// Connect a client that will trigger a session.
|
||||||
|
conn1, err := net.DialTimeout("tcp", addr, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first connection failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send garbage data that will cause handlePacketGroup to hit the panic recovery.
|
||||||
|
// The session's defer/recover should catch it without killing the channel.
|
||||||
|
conn1.Write([]byte{0xFF, 0xFF, 0xFF, 0xFF})
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
conn1.Close()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// The channel should still accept new connections after the panic.
|
||||||
|
conn2, err := net.DialTimeout("tcp", addr, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("channel should still accept connections after session panic: %v", err)
|
||||||
|
}
|
||||||
|
conn2.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestChannelIsolation_CrossChannelRegistryAfterShutdown verifies that the
|
||||||
|
// channel registry handles a shut-down channel gracefully during cross-channel
|
||||||
|
// operations (search, find, disconnect).
|
||||||
|
func TestChannelIsolation_CrossChannelRegistryAfterShutdown(t *testing.T) {
|
||||||
|
channels := createTestChannels(3)
|
||||||
|
reg := NewLocalChannelRegistry(channels)
|
||||||
|
|
||||||
|
// Add sessions to all channels.
|
||||||
|
for i, ch := range channels {
|
||||||
|
conn := &mockConn{}
|
||||||
|
sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player")
|
||||||
|
sess.stage = NewStage("sl1Ns200p0a0u0")
|
||||||
|
ch.Lock()
|
||||||
|
ch.sessions[conn] = sess
|
||||||
|
ch.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate channel 1 shutting down by marking it and clearing sessions.
|
||||||
|
channels[0].Lock()
|
||||||
|
channels[0].isShuttingDown = true
|
||||||
|
channels[0].sessions = make(map[net.Conn]*Session)
|
||||||
|
channels[0].Unlock()
|
||||||
|
|
||||||
|
// Registry operations should still work for remaining channels.
|
||||||
|
found := reg.FindSessionByCharID(2)
|
||||||
|
if found == nil {
|
||||||
|
t.Error("FindSessionByCharID(2) should find session on channel 2")
|
||||||
|
}
|
||||||
|
|
||||||
|
found = reg.FindSessionByCharID(3)
|
||||||
|
if found == nil {
|
||||||
|
t.Error("FindSessionByCharID(3) should find session on channel 3")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session from shut-down channel should not be found.
|
||||||
|
found = reg.FindSessionByCharID(1)
|
||||||
|
if found != nil {
|
||||||
|
t.Error("FindSessionByCharID(1) should not find session on shut-down channel")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchSessions should return only sessions from live channels.
|
||||||
|
results := reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 10)
|
||||||
|
if len(results) != 2 {
|
||||||
|
t.Errorf("SearchSessions should return 2 results from live channels, got %d", len(results))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestChannelIsolation_IndependentStages verifies that stages are per-channel
|
||||||
|
// and one channel's stages don't leak into another.
|
||||||
|
func TestChannelIsolation_IndependentStages(t *testing.T) {
|
||||||
|
channels := createTestChannels(2)
|
||||||
|
|
||||||
|
stageName := "sl1Qs999p0a0u42"
|
||||||
|
|
||||||
|
// Add stage only to channel 1.
|
||||||
|
channels[0].stagesLock.Lock()
|
||||||
|
channels[0].stages[stageName] = NewStage(stageName)
|
||||||
|
channels[0].stagesLock.Unlock()
|
||||||
|
|
||||||
|
// Channel 1 should have the stage.
|
||||||
|
channels[0].stagesLock.RLock()
|
||||||
|
_, ok1 := channels[0].stages[stageName]
|
||||||
|
channels[0].stagesLock.RUnlock()
|
||||||
|
if !ok1 {
|
||||||
|
t.Error("channel 1 should have the stage")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Channel 2 should NOT have the stage.
|
||||||
|
channels[1].stagesLock.RLock()
|
||||||
|
_, ok2 := channels[1].stages[stageName]
|
||||||
|
channels[1].stagesLock.RUnlock()
|
||||||
|
if ok2 {
|
||||||
|
t.Error("channel 2 should not have channel 1's stage")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -293,14 +293,16 @@ func logoutPlayer(s *Session) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update sign sessions and server player count
|
// Update sign sessions and server player count
|
||||||
_, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token)
|
if s.server.db != nil {
|
||||||
if err != nil {
|
_, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token)
|
||||||
panic(err)
|
if err != nil {
|
||||||
}
|
s.logger.Error("Failed to clear sign session", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
_, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID)
|
_, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
s.logger.Error("Failed to update player count", zap.Error(err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.stage == nil {
|
if s.stage == nil {
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ type Server struct {
|
|||||||
sessions map[net.Conn]*Session
|
sessions map[net.Conn]*Session
|
||||||
listener net.Listener // Listener that is created when Server.Start is called.
|
listener net.Listener // Listener that is created when Server.Start is called.
|
||||||
isShuttingDown bool
|
isShuttingDown bool
|
||||||
|
done chan struct{} // Closed on Shutdown to wake background goroutines.
|
||||||
|
|
||||||
stagesLock sync.RWMutex
|
stagesLock sync.RWMutex
|
||||||
stages map[string]*Stage
|
stages map[string]*Stage
|
||||||
@@ -91,6 +92,7 @@ func NewServer(config *Config) *Server {
|
|||||||
erupeConfig: config.ErupeConfig,
|
erupeConfig: config.ErupeConfig,
|
||||||
acceptConns: make(chan net.Conn),
|
acceptConns: make(chan net.Conn),
|
||||||
deleteConns: make(chan net.Conn),
|
deleteConns: make(chan net.Conn),
|
||||||
|
done: make(chan struct{}),
|
||||||
sessions: make(map[net.Conn]*Session),
|
sessions: make(map[net.Conn]*Session),
|
||||||
stages: make(map[string]*Stage),
|
stages: make(map[string]*Stage),
|
||||||
userBinaryParts: make(map[userBinaryPartID][]byte),
|
userBinaryParts: make(map[userBinaryPartID][]byte),
|
||||||
@@ -156,19 +158,23 @@ func (s *Server) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown tries to shut down the server gracefully.
|
// Shutdown tries to shut down the server gracefully. Safe to call multiple times.
|
||||||
func (s *Server) Shutdown() {
|
func (s *Server) Shutdown() {
|
||||||
s.Lock()
|
s.Lock()
|
||||||
|
alreadyShutDown := s.isShuttingDown
|
||||||
s.isShuttingDown = true
|
s.isShuttingDown = true
|
||||||
s.Unlock()
|
s.Unlock()
|
||||||
|
|
||||||
|
if alreadyShutDown {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(s.done)
|
||||||
|
|
||||||
if s.listener != nil {
|
if s.listener != nil {
|
||||||
_ = s.listener.Close()
|
_ = s.listener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.acceptConns != nil {
|
|
||||||
close(s.acceptConns)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) acceptClients() {
|
func (s *Server) acceptClients() {
|
||||||
@@ -186,25 +192,21 @@ func (s *Server) acceptClients() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.acceptConns <- conn
|
select {
|
||||||
|
case s.acceptConns <- conn:
|
||||||
|
case <-s.done:
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) manageSessions() {
|
func (s *Server) manageSessions() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
return
|
||||||
case newConn := <-s.acceptConns:
|
case newConn := <-s.acceptConns:
|
||||||
// Gracefully handle acceptConns channel closing.
|
|
||||||
if newConn == nil {
|
|
||||||
s.Lock()
|
|
||||||
shutdown := s.isShuttingDown
|
|
||||||
s.Unlock()
|
|
||||||
|
|
||||||
if shutdown {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
session := NewSession(s, newConn)
|
session := NewSession(s, newConn)
|
||||||
|
|
||||||
s.Lock()
|
s.Lock()
|
||||||
@@ -236,15 +238,28 @@ func (s *Server) getObjectId() uint16 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) invalidateSessions() {
|
func (s *Server) invalidateSessions() {
|
||||||
for !s.isShuttingDown {
|
ticker := time.NewTicker(10 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.done:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Lock()
|
||||||
|
var timedOut []*Session
|
||||||
for _, sess := range s.sessions {
|
for _, sess := range s.sessions {
|
||||||
if time.Since(sess.lastPacket) > time.Second*time.Duration(30) {
|
if time.Since(sess.lastPacket) > time.Second*time.Duration(30) {
|
||||||
s.logger.Info("session timeout", zap.String("Name", sess.Name))
|
timedOut = append(timedOut, sess)
|
||||||
logoutPlayer(sess)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second * 10)
|
s.Unlock()
|
||||||
|
|
||||||
|
for _, sess := range timedOut {
|
||||||
|
s.logger.Info("session timeout", zap.String("Name", sess.Name))
|
||||||
|
logoutPlayer(sess)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"erupe-ce/common/mhfcourse"
|
"erupe-ce/common/mhfcourse"
|
||||||
_config "erupe-ce/config"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -172,7 +171,7 @@ func (s *Session) sendLoop() {
|
|||||||
s.logger.Warn("Failed to send packet", zap.Error(err))
|
s.logger.Warn("Failed to send packet", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond)
|
time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,7 +214,7 @@ func (s *Session) recvLoop() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.handlePacketGroup(pkt)
|
s.handlePacketGroup(pkt)
|
||||||
time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond)
|
time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user