# SelectServer.py
# Authored by Jeffrey Zhang
import socket
import pickle
import select
import threading
import time
from dataclasses import replace
from Messages import Message
from TransactionDeliverer import TransactionDeliverer
class Client:
    """State for one inbound peer connection accepted by ``Server``.

    Each frame on the socket is a 2-byte big-endian length followed by that
    many bytes of pickled ``Message`` data.
    """

    def __init__(self, td: "TransactionDeliverer", sock: socket.socket, ip: str, port: int):
        """Wrap an accepted connection.

        Args:
            td: Shared deliverer whose ``received_messages``, ``td_lock``,
                ``client_sockets`` and ``node_name`` this client reads/updates.
            sock: The accepted connection; ``handle`` reads it with ``recv``.
            ip: Peer address as returned by ``accept()``.
            port: Peer port as returned by ``accept()``.
        """
        self.socket = sock
        self.ip = ip
        self.port = port
        # Filled in from the first non-duplicate message seen on this socket.
        self.client_name = None
        self.td = td

    def _recv_exact(self, num_bytes):
        """Return exactly *num_bytes* bytes from the socket, or None on EOF.

        ``recv`` may return fewer bytes than requested, so keep reading until
        the full amount has arrived; an empty ``recv`` means the peer closed.
        """
        chunks = []
        remaining = num_bytes
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def handle(self):
        """Read and process exactly one frame from the socket.

        Duplicate ``message_id`` values (already in ``td.received_messages``)
        are ignored. Otherwise the message's ``receive_action`` runs under
        ``td.td_lock`` and, if it returns a truthy value, a copy addressed to
        each other known peer is passed to ``td.multicast``.

        Returns:
            False if the peer closed or reset the connection before a full
            frame arrived (the caller should disconnect), True otherwise.
        """
        try:
            header = self._recv_exact(2)
            if header is None:
                print("Less than 2 bytes")
                return False
            data = self._recv_exact(int.from_bytes(header, 'big'))
            if data is None:
                print("Less than message length")
                return False
        except ConnectionError as err:
            # A reset peer is just another kind of disconnect.
            print(f"Connection error - {self.client_name} - {err}")
            return False

        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload; this is only safe if every peer that can reach this port
        # is trusted.
        obj: Message = pickle.loads(data)

        # The same message can arrive along several paths; process it once.
        if obj.message_id in self.td.received_messages:
            return True
        self.td.received_messages.add(obj.message_id)

        # NOTE(review): rebroadcast copies keep the original ``sender`` (only
        # ``receiver`` is replaced below), so this assumes the first new
        # message on this connection was authored by the connected peer.
        if not self.client_name:
            self.client_name = obj.sender

        with self.td.td_lock:
            rebroadcast = obj.receive_action(self.td)

        # Rebroadcast for reliable multicast to every known peer except this
        # connection's client and this node itself.
        if rebroadcast:
            skip = (self.client_name, self.td.node_name)
            reliable_multicast_messages = {
                peer: replace(obj, receiver=peer)
                for peer in self.td.client_sockets
                if peer not in skip
            }
            # TODO This may have to be locked as multicast has logic for dropping clients
            self.td.multicast(reliable_multicast_messages)
        return True
class Server(threading.Thread):
    """Thread that accepts peer connections and reads their frames via select().

    ``run`` (invoked by ``start()``) waits until ``td.client_sockets`` holds
    at least ``num_clients`` entries, then multiplexes the listening socket
    and every accepted connection with ``select``. Set ``end`` to stop it;
    on exit all sockets are closed and ``td.shutdown()`` is called.
    """

    def __init__(self, td: "TransactionDeliverer", host, port, num_clients):
        super().__init__()
        self.logfile = None
        self.td = td
        self.host = host
        self.port = port
        # Used to wait until all clients are connected
        self.num_clients = num_clients
        # Set from another thread to stop run(); polled at least once a second.
        self.end = threading.Event()
        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server_socket.setblocking(False)
        # Accepted client socket -> Client wrapper.
        self.clients = {}
        # Includes both the server and client sockets, passed to select
        self.waiting_sockets = {self.server_socket}

    @staticmethod
    def _close_socket(sock: socket.socket):
        """Shut down and close *sock*, tolerating an already-dead connection."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # e.g. ENOTCONN when the peer already closed, or for an
            # unconnected/listening socket on some platforms. Closing below
            # is what actually matters.
            pass
        sock.close()

    def disconnect_client(self, client: "Client", sock: socket.socket):
        """Forget *client*, stop selecting on *sock*, and close it.

        The matching outgoing socket in ``td.client_sockets`` is not touched.
        """
        print(f"Disconnected client {client.client_name}")
        self.waiting_sockets.discard(sock)
        self.clients.pop(sock, None)
        self._close_socket(sock)

    def _accept_client(self):
        """Accept one pending connection and register it for select()."""
        try:
            client_socket, (ip, port) = self.server_socket.accept()
        except (BlockingIOError, ConnectionAbortedError):
            # The pending connection vanished between select() and accept().
            return
        # Client.handle() reads with blocking recv loops. Whether a socket
        # accepted from a non-blocking listener is itself non-blocking is
        # OS-dependent, so force blocking mode explicitly.
        client_socket.setblocking(True)
        self.clients[client_socket] = Client(self.td, client_socket, ip, port)
        self.waiting_sockets.add(client_socket)

    def handle_readable_socket(self, sock: socket.socket):
        """Accept on the listening socket, or read one frame from a client."""
        if sock is self.server_socket:
            self._accept_client()
            return
        client = self.clients.get(sock)
        if client is None:
            # Already disconnected earlier in this select() round.
            return
        try:
            keep_open = client.handle()
        except OSError as err:
            # Drop only this connection instead of killing the server thread.
            print(f"Socket error - {client.client_name} - {err}")
            self.disconnect_client(client, sock)
            return
        if not keep_open:
            print(f"No data - {client.client_name} - disconnect")
            self.disconnect_client(client, sock)

    def handle_err_socket(self, sock: socket.socket):
        """Disconnect a client whose socket select() reported as exceptional."""
        client = self.clients.get(sock)
        if client is None:
            # The listening socket, or a client already disconnected in this
            # select() round: nothing to tear down.
            return
        print(f"Client error - {client.client_name}")
        self.disconnect_client(client, sock)

    def _peers_ready(self):
        """Return True once td has at least num_clients client sockets."""
        return (
            hasattr(self.td, "client_sockets")
            and len(self.td.client_sockets) >= self.num_clients
        )

    # Called by thread start
    def run(self):
        """Serve connections until ``end`` is set, then release everything.

        ``shutdown`` runs in a ``finally`` so sockets and ``td`` are cleaned
        up even if binding or a handler raises.
        """
        try:
            self.server_socket.bind((self.host, self.port))
            self.server_socket.listen()
            # Client sockets may not be set up yet; wait until they are, but
            # stop waiting as soon as end is set.
            while not self._peers_ready():
                if self.end.wait(1):
                    return
            while not self.end.is_set():
                # Select takes in: readable, writable, err, timeout.
                # The timeout lets us re-check end at least once a second.
                readable, _, errored = select.select(
                    self.waiting_sockets, [], self.waiting_sockets, 1
                )
                for s in readable:
                    self.handle_readable_socket(s)
                for s in errored:
                    self.handle_err_socket(s)
        finally:
            self.shutdown()

    def shutdown(self):
        """Close every socket, shut down td, and close the log file if set."""
        for s in list(self.waiting_sockets):
            self._close_socket(s)
        self.server_socket.close()
        self.td.shutdown()
        if self.logfile is not None:
            self.logfile.close()