Add graceful shutdown to channel server

This commit is contained in:
Andrew Gutekanst
2020-01-13 18:36:55 -05:00
parent 30219b8bcf
commit 5f1d429c12
3 changed files with 43 additions and 7 deletions

View File

@@ -6,6 +6,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time"
"github.com/Andoryuuta/Erupe/config" "github.com/Andoryuuta/Erupe/config"
"github.com/Andoryuuta/Erupe/server/channelserver" "github.com/Andoryuuta/Erupe/server/channelserver"
@@ -113,7 +114,10 @@ func main() {
<-c <-c
logger.Info("Trying to shutdown gracefully.") logger.Info("Trying to shutdown gracefully.")
channelServer.Shutdown()
signServer.Shutdown() signServer.Shutdown()
entranceServer.Shutdown() entranceServer.Shutdown()
launcherServer.Shutdown() launcherServer.Shutdown()
time.Sleep(5 * time.Second)
} }

View File

@@ -26,8 +26,9 @@ type Server struct {
acceptConns chan net.Conn acceptConns chan net.Conn
deleteConns chan net.Conn deleteConns chan net.Conn
sessions map[net.Conn]*Session sessions map[net.Conn]*Session
listenAddr string
listener net.Listener // Listener that is created when Server.Start is called. listener net.Listener // Listener that is created when Server.Start is called.
isShuttingDown bool
} }
// NewServer creates a new Server type. // NewServer creates a new Server type.
@@ -57,13 +58,30 @@ func (s *Server) Start() error {
return nil return nil
} }
// Shutdown tries to shut down the server gracefully.
func (s *Server) Shutdown() {
s.Lock()
s.isShuttingDown = true
s.Unlock()
s.listener.Close()
close(s.acceptConns)
}
func (s *Server) acceptClients() { func (s *Server) acceptClients() {
for { for {
conn, err := s.listener.Accept() conn, err := s.listener.Accept()
if err != nil { if err != nil {
// TODO(Andoryuuta): Implement shutdown logic to end this goroutine cleanly here. s.Lock()
fmt.Println(err) shutdown := s.isShuttingDown
continue s.Unlock()
if shutdown {
break
} else {
s.logger.Warn("Error accepting client", zap.Error(err))
continue
}
} }
s.acceptConns <- conn s.acceptConns <- conn
} }
@@ -73,6 +91,17 @@ func (s *Server) manageSessions() {
for { for {
select { select {
case newConn := <-s.acceptConns: case newConn := <-s.acceptConns:
// Gracefully handle acceptConns channel closing.
if newConn == nil {
s.Lock()
shutdown := s.isShuttingDown
s.Unlock()
if shutdown {
return
}
}
session := NewSession(s, newConn) session := NewSession(s, newConn)
s.Lock() s.Lock()

View File

@@ -10,11 +10,13 @@ import (
"github.com/Andoryuuta/Erupe/network" "github.com/Andoryuuta/Erupe/network"
"github.com/Andoryuuta/Erupe/network/mhfpacket" "github.com/Andoryuuta/Erupe/network/mhfpacket"
"github.com/Andoryuuta/byteframe" "github.com/Andoryuuta/byteframe"
"go.uber.org/zap"
) )
// Session holds state for the channel server connection. // Session holds state for the channel server connection.
type Session struct { type Session struct {
sync.Mutex sync.Mutex
logger *zap.Logger
server *Server server *Server
rawConn net.Conn rawConn net.Conn
cryptConn *network.CryptConn cryptConn *network.CryptConn
@@ -23,6 +25,7 @@ type Session struct {
// NewSession creates a new Session type. // NewSession creates a new Session type.
func NewSession(server *Server, conn net.Conn) *Session { func NewSession(server *Server, conn net.Conn) *Session {
s := &Session{ s := &Session{
logger: server.logger,
server: server, server: server,
rawConn: conn, rawConn: conn,
cryptConn: network.NewCryptConn(conn), cryptConn: network.NewCryptConn(conn),
@@ -33,15 +36,15 @@ func NewSession(server *Server, conn net.Conn) *Session {
// Start starts the session packet read&handle loop. // Start starts the session packet read&handle loop.
func (s *Session) Start() { func (s *Session) Start() {
go func() { go func() {
fmt.Println("Channel server got connection!") s.logger.Info("Channel server got connection!")
// Unlike the sign and entrance server, // Unlike the sign and entrance server,
// the client DOES NOT initalize the channel connection with 8 NULL bytes. // the client DOES NOT initalize the channel connection with 8 NULL bytes.
for { for {
pkt, err := s.cryptConn.ReadPacket() pkt, err := s.cryptConn.ReadPacket()
if err != nil { if err != nil {
fmt.Println(err) s.logger.Warn("Error on channel server readpacket", zap.Error(err))
fmt.Println("Error on channel server readpacket")
return return
} }