refactor: move cmd_id into net

This commit is contained in:
Naruse
2024-11-08 11:05:39 +08:00
parent 23f4f7bc63
commit bfbf1bb2ab
4 changed files with 25 additions and 21 deletions

View File

@@ -1,7 +1,8 @@
from game_server.config.log import Info from utils.logger import Info
from game_server.net.session import Session from game_server.net.session import Session
import asyncio import asyncio
class Gateway: class Gateway:
def __init__(self, server_ip, game_server_port) -> None: def __init__(self, server_ip, game_server_port) -> None:
self.server_ip = server_ip self.server_ip = server_ip
@@ -16,10 +17,9 @@ class Gateway:
async def start_server(self): async def start_server(self):
session = Session() session = Session()
server = await asyncio.start_server(session.handle_connection, self.server_ip, self.game_server_port) server = await asyncio.start_server(
session.handle_connection, self.server_ip, self.game_server_port
)
Info("Gateway listening...") Info("Gateway listening...")
async with server: async with server:
await server.serve_forever() await server.serve_forever()

View File

@@ -1,5 +1,5 @@
import struct import struct
from game_server.protocol.cmd_id import CmdID from game_server.net.cmd_id import CmdID
class Packet: class Packet:
def __init__(self, buf: bytes): def __init__(self, buf: bytes):

View File

@@ -1,5 +1,5 @@
from game_server.config.log import Error, Info from utils.logger import Error, Info
from game_server.protocol.cmd_id import CmdID from game_server.net.cmd_id import CmdID
from game_server.net.packet import Packet from game_server.net.packet import Packet
from lib import proto as protos from lib import proto as protos
import traceback import traceback
@@ -11,7 +11,8 @@ from game_server.game.player import Player
class Session: class Session:
player : Player player: Player
def __init__(self) -> None: def __init__(self) -> None:
self.writer = None self.writer = None
self.pending_notifies = [] self.pending_notifies = []
@@ -28,12 +29,12 @@ class Session:
except Exception as ex: except Exception as ex:
Error(f"Error in KeepAliveLoop: {ex}") Error(f"Error in KeepAliveLoop: {ex}")
break break
await asyncio.sleep(3) await asyncio.sleep(3)
async def handle_connection(self, reader, writer): async def handle_connection(self, reader, writer):
self.writer = writer self.writer = writer
addr = writer.get_extra_info('peername') addr = writer.get_extra_info("peername")
Info(f"Accepted connection from {addr}") Info(f"Accepted connection from {addr}")
prefix = bytes([0x01, 0x23, 0x45, 0x67]) prefix = bytes([0x01, 0x23, 0x45, 0x67])
@@ -41,7 +42,7 @@ class Session:
try: try:
while True: while True:
data = await reader.read(1 << 16) data = await reader.read(1 << 16)
if not data: if not data:
break break
@@ -59,7 +60,7 @@ class Session:
end = segment.find(suffix, start) end = segment.find(suffix, start)
if end == -1: if end == -1:
break break
end += len(suffix) end += len(suffix)
packets.append(segment[start:end]) packets.append(segment[start:end])
offset += end offset += end
@@ -82,14 +83,14 @@ class Session:
def create_packet(self, proto_message: betterproto.Message) -> Packet: def create_packet(self, proto_message: betterproto.Message) -> Packet:
return Packet.send_packet(proto_message) return Packet.send_packet(proto_message)
def is_valid_packet(self,data: bytes) -> bool: def is_valid_packet(self, data: bytes) -> bool:
hex_string = data.hex().upper() hex_string = data.hex().upper()
return hex_string.startswith("01234567") and hex_string.endswith("89ABCDEF") return hex_string.startswith("01234567") and hex_string.endswith("89ABCDEF")
def pending_notify(self, proto_message: betterproto.Message): def pending_notify(self, proto_message: betterproto.Message):
packet = Packet.send_packet(proto_message) packet = Packet.send_packet(proto_message)
self.pending_notifies.append(packet) self.pending_notifies.append(packet)
def send_pending_notifies_in_thread(self): def send_pending_notifies_in_thread(self):
thread = threading.Thread(target=self._run_send_pending_notifies) thread = threading.Thread(target=self._run_send_pending_notifies)
thread.start() thread.start()
@@ -99,18 +100,19 @@ class Session:
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
loop.run_until_complete(self._send_pending_notifies()) loop.run_until_complete(self._send_pending_notifies())
loop.close() loop.close()
async def _send_pending_notifies(self): async def _send_pending_notifies(self):
for packet in self.pending_notifies: for packet in self.pending_notifies:
await self.send(packet) await self.send(packet)
self.pending_notifies.clear() self.pending_notifies.clear()
async def process_packet(self, packet : Packet): async def process_packet(self, packet: Packet):
if packet.cmd_id not in CmdID._value2member_map_: if packet.cmd_id not in CmdID._value2member_map_:
Error(f"CmdId {packet.cmd_id} not recognized!") Error(f"CmdId {packet.cmd_id} not recognized!")
return return
request_name = CmdID(packet.cmd_id).name request_name = CmdID(packet.cmd_id).name
if request_name == "KeepAliveNotify": return #await self.send(packet.send_packet(protos.KeepAliveNotify())) if request_name == "KeepAliveNotify":
return # await self.send(packet.send_packet(protos.KeepAliveNotify()))
try: try:
try: try:
req: betterproto.Message = getattr(protos, request_name)() req: betterproto.Message = getattr(protos, request_name)()
@@ -120,7 +122,9 @@ class Session:
try: try:
Info(f"RECV packet: {request_name} ({packet.cmd_id})") Info(f"RECV packet: {request_name} ({packet.cmd_id})")
handle_module = importlib.import_module(f"game_server.packet.handlers.{request_name}") handle_module = importlib.import_module(
f"game_server.packet.handlers.{request_name}"
)
handle_function = handle_module.handle handle_function = handle_module.handle
handle_result = await handle_function(self, req) handle_result = await handle_function(self, req)
if not handle_result: if not handle_result:
@@ -150,4 +154,4 @@ class Session:
Info(f"Sent packet: {packet_name} ({packet.cmd_id})") Info(f"Sent packet: {packet_name} ({packet.cmd_id})")
except Exception as ex: except Exception as ex:
Error(f"Failed to send {packet_name}: {ex}") Error(f"Failed to send {packet_name}: {ex}")
traceback.print_exc() traceback.print_exc()