Skip to content
Snippets Groups Projects
SelectServer.py 5.44 KiB
import socket
import pickle
import select
import threading
import time
from dataclasses import replace
from Messages import Message
from TransactionDeliverer import TransactionDeliverer

class Client:
    """Server-side handler for one accepted peer connection.

    Each call to handle() reads one frame from the socket -- a 2-byte
    big-endian length followed by that many bytes of pickled Message --
    hands it to the TransactionDeliverer and, when the message's
    receive_action returns a truthy value, re-multicasts a copy to the other
    peers known to the TransactionDeliverer (reliable multicast).
    """

    # Size of the big-endian length prefix that precedes every payload.
    LENGTH_PREFIX_SIZE = 2

    def __init__(self, td: "TransactionDeliverer", sock: socket.socket, ip: str, port: int):
        self.socket = sock
        self.ip = ip
        self.port = port
        # Filled in from the `sender` of the first non-duplicate message
        # handled on this connection.
        self.client_name = None
        self.td = td

    def _recv_exact(self, num_bytes: int, short_read_msg: str, trace_progress: bool = False):
        """Read exactly num_bytes from the socket.

        Returns the bytes read, or None if recv() returned b"" (peer closed)
        before num_bytes arrived; short_read_msg is printed in that case.
        With trace_progress, each recv() after the first prints the number
        of bytes received so far and the target length.
        """
        data = b""
        while len(data) < num_bytes:
            chunk = self.socket.recv(num_bytes - len(data))
            if trace_progress and data:
                print(len(data), num_bytes)
            if not chunk:
                print(short_read_msg)
                return None
            data += chunk
        return data

    def handle(self):
        """Receive and process one framed message from this peer.

        Returns False when the connection should be dropped (the peer closed
        or reset the connection before a full frame arrived), True otherwise,
        including when the message was a duplicate and was ignored.
        """
        try:
            header = self._recv_exact(self.LENGTH_PREFIX_SIZE, "Less than 2 bytes")
            if header is None:
                return False
            message_len = int.from_bytes(header, 'big')

            data = self._recv_exact(message_len, "Less than message length", trace_progress=True)
            if data is None:
                return False
        except ConnectionError as err:
            # A reset/aborted peer must not escape into the server's select
            # loop; report it as a dropped connection instead.
            print(f"Connection error - {self.client_name} - {err!r}")
            return False

        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload; this is only safe if every peer able to connect is trusted.
        obj: Message = pickle.loads(data)

        # The same message can arrive more than once because receivers
        # re-forward it (see the rebroadcast below); process it only once.
        if obj.message_id in self.td.received_messages:
            return True
        self.td.received_messages.add(obj.message_id)

        if not self.client_name:
            self.client_name = obj.sender

        with self.td.td_lock:
            rebroadcast = obj.receive_action(self.td)

        # Rebroadcast for reliable multicast
        if rebroadcast:
            # One copy per known peer, re-addressed to it, excluding this
            # connection's client_name and this node's own name.
            reliable_multicast_messages = {
                client_name: replace(obj, receiver=client_name)
                for client_name in self.td.client_sockets.keys()
                if client_name != self.client_name and client_name != self.td.node_name
            }

            # TODO This may have to be locked as multicast has logic for dropping clients
            self.td.multicast(reliable_multicast_messages)

        return True

class Server(threading.Thread):
    """Thread that accepts peer connections and services them with select().

    Every accepted socket gets a Client handler, and readable client sockets
    are passed to Client.handle(). run() loops until `end` is set, then calls
    shutdown().
    """

    def __init__(self, td: "TransactionDeliverer", host, port, num_clients):
        super().__init__()

        # None by default; shutdown() closes it if it has been assigned.
        self.logfile = None

        self.td = td
        self.host = host
        self.port = port

        # Used to wait until all clients are connected
        self.num_clients = num_clients

        # Once set, run() leaves its select loop and shuts down.
        self.end = threading.Event()

        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server_socket.setblocking(False)
        # Accepted socket -> its Client handler.
        self.clients = {}

        # Includes both the server and client sockets,
        # passed to select
        self.waiting_sockets = {self.server_socket}

    @staticmethod
    def _close_socket(sock: socket.socket):
        """Shut down and close sock, tolerating a socket that is not connected."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # shutdown() raises (typically ENOTCONN) once the peer has reset
            # the connection; the descriptor must still be closed.
            pass
        sock.close()

    def disconnect_client(self, client: "Client", sock: socket.socket):
        """Stop watching sock, forget its Client handler and close the socket.

        Safe to call more than once for the same socket.
        """
        print(f"Disconnected client {client.client_name}")
        self.waiting_sockets.discard(sock)
        self.clients.pop(sock, None)
        self._close_socket(sock)
        # NOTE(review): the entry for this peer in td.client_sockets is not
        # closed or removed here (an earlier version did so) -- confirm intended.

    def handle_readable_socket(self, sock: socket.socket):
        """Accept a pending connection, or process one message from a client."""
        if sock is self.server_socket:
            # Accept a new connection
            try:
                client_socket, (ip, port) = sock.accept()
            except BlockingIOError:
                # select() can report the non-blocking listener as ready even
                # though no connection is pending by the time we accept.
                return
            client = Client(self.td, client_socket, ip, port)
            self.clients[client_socket] = client
            self.waiting_sockets.add(client_socket)
            #print(f"Accepted connection from ip: {ip}, port: {port}")
            return

        # Read from node client
        client = self.clients[sock]
        try:
            still_connected = client.handle()
        except ConnectionError:
            # A reset/aborted peer is a disconnect, not a reason to kill the loop.
            still_connected = False
        if not still_connected:
            # No data, disconnect
            print(f"No data - {client.client_name} - disconnect")
            self.disconnect_client(client, sock)

    def handle_err_socket(self, sock: socket.socket):
        """Drop the client whose socket select() reported in its error list."""
        client = self.clients.get(sock)
        if client is None:
            # The listening socket, or a client already disconnected while
            # processing the readable list of the same select() round.
            return
        print(f"Client error - {client.client_name}")
        self.disconnect_client(client, sock)

    def _peers_ready(self):
        """Return True once td.client_sockets exists and has num_clients entries."""
        return hasattr(self.td, "client_sockets") and len(self.td.client_sockets) >= self.num_clients

    # Called by thread start
    def run(self):
        """Bind and listen, wait for td's client sockets, then serve until `end` is set.

        shutdown() runs when the loop exits, including when an exception
        escapes it, so sockets and td are not left open.
        """
        self.server_socket.bind((self.host, self.port))
        self.server_socket.listen()

        try:
            # Client sockets may not be set up yet, wait until they are.
            # Event.wait() rather than sleep() so setting `end` also ends the wait.
            while not self._peers_ready():
                if self.end.wait(1):
                    break

            while True:
                # Select takes in: readable, writable, err, timeout
                # Set timeout so we can check if thread is ended
                readable_sockets, _, err_sockets = select.select(
                    self.waiting_sockets, [], self.waiting_sockets, 1
                )

                for s in readable_sockets:
                    self.handle_readable_socket(s)

                for s in err_sockets:
                    self.handle_err_socket(s)

                if self.end.is_set():
                    break
        finally:
            self.shutdown()

    def shutdown(self):
        """Close every watched socket, then shut down td and close the log file.

        A socket that fails to shut down cleanly no longer prevents the
        remaining sockets, td and the log file from being closed.
        """
        for s in list(self.waiting_sockets):
            self._close_socket(s)

        # Normally already closed above; socket.close() is idempotent.
        self.server_socket.close()

        self.td.shutdown()

        if self.logfile is not None:
            self.logfile.close()