#! /usr/bin/env python3

import sys, os, pathlib, threading, socketserver, socket, struct, copy, collections
# Make the sibling lobby_client module importable even when this file is not run from its own directory.
# sys.path entries must be strings: the import system silently skips anything else (e.g. a Path object).
sys.path.append(str(pathlib.Path(__file__).resolve().parent))
from lobby_client import Game, receive_list

class Lobby:
	"""Confederated Warzone 2100 lobby: game registry plus protocol handler.

	Keeps the games registered by hosts connected to this lobby and answers
	game-list queries, merging in the games listed by the peer lobbies in
	``relays``. ``handle_request`` may run concurrently on several threads;
	``lock`` guards ``games`` and ``game_id``.
	"""

	default_motd = 'Welcome to the confederated Warzone 2100 lobby server'

	# A peer lobby queried on "list": confederated peers are asked with "rela",
	# others with the standard "list" command. bind_address is passed through to
	# receive_list (an empty string presumably means "let the OS choose" - confirm).
	Relay = collections.namedtuple('Relay', ['address', 'port', 'confederated', 'bind_address'])

	# Seconds to wait while checking that a new host's game port answers a join.
	join_check_timeout = 5
	# Seconds a registered host may stay silent; keep is sent every 25 seconds.
	keepalive_timeout = 30

	def __init__(self):
		self.lock = threading.Lock()
		self.motd = __class__.default_motd
		self.games = {}  # game_id -> Game registered on this lobby
		self.game_id = 0  # last game id handed out by "gaId"
		self.fqdn = None  # public name stamped into hosted games; None = the connection's local address
		self.relays = [
			__class__.Relay('warzone2100.retropaganda.info', 9990, True, ''),
			__class__.Relay('lobby.wz2100.net', 9990, False, os.environ.get('WARZONE2100_SOCKET_BIND_IPV4_ADDRESS', '')),
		]

	@staticmethod
	def _probe_host(host_address, host_port, source_address, timeout=5):
		"""Check that a game host accepts the special "join" handshake.

		Sends NETCODE_VERSION_MAJOR and NETCODE_VERSION_MINOR as zero and waits for
		the 12-byte reply ("WZLR", version, game id) without checking its values.

		Raises:
			OSError: the connection failed or timed out.
			EOFError: the host closed the connection before sending the full reply.
		"""
		with socket.create_connection((host_address, host_port), timeout, source_address) as join_socket:
			join_socket.sendall(b'\0' * (4 + 4))
			size_left = 4 + 4 + 4
			while size_left:
				chunk = join_socket.recv(size_left)
				# recv() returns b'' at EOF; without this check the loop would spin forever.
				if not chunk: raise EOFError('host closed the connection during the join check')
				size_left -= len(chunk)

	def handle_request(self, request_handler):
		"""Serve one client connection until its command completes or it disconnects.

		The first 5 bytes (a NUL-terminated 4-letter command) select the action:
		- rela: send this lobby's own games.
		- list: send own games plus those fetched from ``self.relays``, then
		  status 200 with the MOTD and an empty "more games" count.
		- gaId: hand out a new game id; if followed by addg and a game struct,
		  register the game, check the host answers a join, and keep the game
		  registered until the host disconnects or stops sending keep messages.

		Connection errors (EOF, reset, timeout) are logged instead of raised:
		a host closing its socket is the normal end of a registration.
		"""
		def log(message): print(f'{request_handler.client_address[0]}:{request_handler.client_address[1]}: {message}', flush=True)

		log('connect')
		try:
			connection = request_handler.request
			connection.settimeout(request_handler.server.timeout)

			def send(binary_format, *data): connection.sendall(struct.pack(binary_format, *data))

			def receive(size, received=None):
				"""Return ``received`` extended to ``size`` bytes; raise EOFError if the peer closes first."""
				if received is None: received = bytearray()
				while len(received) < size:
					chunk = connection.recv(size - len(received))
					if not chunk: raise EOFError('socket connection broken')
					received += chunk
				return received

			def unpack_game(game_id, data):
				"""Decode a game struct sent by the host, stamp this lobby's address on it and store it."""
				game = Game.from_bytes(data)
				# The host may leave its address empty: use the address it connected from.
				if not game.host_address: game.host_address = request_handler.client_address[0]
				if game.game_id != game_id: raise ValueError('game id changed')
				# Without an fqdn, advertise the local address of this connection: the
				# server's bound address may be the wildcard 0.0.0.0, useless to clients.
				game.host2 = self.fqdn if self.fqdn is not None else connection.getsockname()[0]
				game.host3 = str(request_handler.server.server_address[1])
				with self.lock: self.games[game_id] = game
				return game

			def remove_game(game_id):
				# pop() rather than del: this runs in a finally and must not mask the real error.
				with self.lock: self.games.pop(game_id, None)

			def send_games(games):
				send('!I', len(games))
				for game in games: connection.sendall(game.to_bytes())

			def send_status_and_motd(status):
				motd_bytes = self.motd.encode()
				send(f'!2I{len(motd_bytes)}s', status, len(motd_bytes), motd_bytes)

			def send_more_games(): send('!I', 0) # See explanation in warzone2100/lib/netplay/netplay.cpp function NETenumerateGames

			def send_game_id():
				with self.lock:
					self.game_id += 1
					current_game_id = self.game_id
				send('!I', current_game_id)
				return current_game_id

			command = receive(5)
			if command == b'rela\0':
				log('rela')
				# Snapshot under the lock; the (slow) sending happens outside it.
				with self.lock: games = copy.deepcopy(list(self.games.values()))
				send_games(games)
			elif command == b'list\0':
				log('list')
				with self.lock: all_games = copy.deepcopy(list(self.games.values()))
				for relay in self.relays:
					if relay.address == self.fqdn: continue  # never query ourselves
					try:
						query_command = b'rela\0' if relay.confederated else b'list\0'
						status, _, games = receive_list(relay.address, relay.port, relay.bind_address, query_command)
						log(f'relay {relay} {status} {len(games)}')
						if status == 200:
							for game in games:
								# Record which lobby listed the game, as unpack_game does for local ones.
								game.host2 = relay.address
								game.host3 = str(relay.port)
								all_games.append(game)
					except Exception as e: print('relay error:', relay, e, file=sys.stderr)  # best effort per relay
				send_games(all_games)
				send_status_and_motd(200)
				send_more_games()
			elif command == b'gaId\0':
				log('host start')
				game_id = send_game_id()
				command = receive(5)
				if command != b'addg\0':
					log(f'unexpected host command {bytes(command)!r}')
					return
				game = unpack_game(game_id, receive(Game.binary_format.size))
				try:
					try:
						self._probe_host(
							game.host_address,
							game.host_port,
							(request_handler.server.server_address[0], 0),
							self.join_check_timeout,
						)
					except Exception as e:
						log(f'host unreachable: {e!r}')
						send_status_and_motd(400)
						return
					send_status_and_motd(200)
					connection.settimeout(self.keepalive_timeout)
					while True:
						data = receive(5)
						if data == b'keep\0':
							log('host keep')
							continue
						# Anything else is the start of an updated game struct (no command prefix).
						log('host update')
						data = receive(Game.binary_format.size, data)
						unpack_game(game_id, data)
						send_status_and_motd(200)
				finally: remove_game(game_id)
			else:
				log(f'unknown command {bytes(command)!r}')
		except (EOFError, OSError) as e:
			log(f'connection closed: {e!r}')
		finally: log('disconnect')

class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
	"""TCP server handling each connection in its own thread, exposing a shared ``lobby``."""

	# Also read by Lobby.handle_request as the initial per-connection socket timeout.
	timeout = 5
	# Allow an immediate restart on the same port while old sockets sit in TIME_WAIT.
	allow_reuse_address = True
	# Handler threads are not tracked or joined: hosted-game connections last as long
	# as the host stays registered and would otherwise block server_close() on exit.
	daemon_threads = True

	def __init__(self, lobby, *args, **kwargs):
		"""Create the server; ``args``/``kwargs`` are passed to ``socketserver.TCPServer``."""
		# Attach the lobby before the base class binds and starts listening.
		self.lobby = lobby
		super().__init__(*args, **kwargs)

class RequestHandler(socketserver.BaseRequestHandler):
	"""Per-connection handler that delegates all protocol work to the server's lobby."""

	def handle(self):
		lobby = self.server.lobby
		lobby.handle_request(self)

def main():
	"""Parse command-line options and serve the lobby forever."""
	import argparse
	parser = argparse.ArgumentParser(description='Confederated Warzone 2100 lobby server')
	# Plain optional flags, no nargs='?': with it, a bare flag such as "--port" silently
	# yields None and only fails later (at bind time, or on the first MOTD encode).
	parser.add_argument('--fqdn', default=None, help="this lobby's public host name, stamped into games hosted here")
	parser.add_argument('--address', default='', help='address to listen on (default: all interfaces)')
	parser.add_argument('--port', type=int, default=9990, help='TCP port to listen on (default: %(default)s)')
	parser.add_argument('--motd', default=Lobby.default_motd, help='message of the day sent to clients')
	args = parser.parse_args()
	lobby = Lobby()
	lobby.fqdn = args.fqdn
	lobby.motd = args.motd
	with ThreadedTCPServer(lobby, (args.address, args.port), RequestHandler) as server:
		server.serve_forever()

# Start the server only when executed as a script, not when imported.
if __name__ == '__main__':
	main()
