Skip to content
Snippets Groups Projects
Commit 0e123715 authored by mdasari2's avatar mdasari2
Browse files

Master

parent de2cce13
No related branches found
No related tags found
No related merge requests found
Showing
with 802 additions and 0 deletions
# User Customize
checkpoint
# dataset
data
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
\ No newline at end of file
import logging
from abc import abstractmethod
from ..communication.gRPC.grpc_comm_manager import GRPCCommManager
from ..communication.message import Message
from ..communication.observer import Observer
class ClientManager(Observer):
def __init__(self, args, comm=None, rank=0, size=0, backend="MPI"):
self.args = args
self.size = size
self.rank = rank
self.backend = backend
HOST = "0.0.0.0"
PORT = 50000 + rank
self.com_manager = GRPCCommManager(
HOST, PORT, ip_config_path=args.grpc_ipconfig_path, client_id=rank, client_num=size - 1
)
self.com_manager.add_observer(self)
self.message_handler_dict = dict()
def run(self):
self.register_message_receive_handlers()
self.com_manager.handle_receive_message()
def get_sender_id(self):
return self.rank
def receive_message(self, msg_type, msg_params) -> None:
# logging.info("receive_message. rank_id = %d, msg_type = %s. msg_params = %s" % (
# self.rank, str(msg_type), str(msg_params.get_content())))
handler_callback_func = self.message_handler_dict[msg_type]
handler_callback_func(msg_params)
def send_message(self, message):
msg = Message()
msg.add(Message.MSG_ARG_KEY_TYPE, message.get_type())
msg.add(Message.MSG_ARG_KEY_SENDER, message.get_sender_id())
msg.add(Message.MSG_ARG_KEY_RECEIVER, message.get_receiver_id())
for key, value in message.get_params().items():
# logging.info("%s == %s" % (key, value))
msg.add(key, value)
self.com_manager.send_message(msg)
@abstractmethod
def register_message_receive_handlers(self) -> None:
pass
def register_message_receive_handler(self, msg_type, handler_callback_func):
self.message_handler_dict[msg_type] = handler_callback_func
def finish(self):
logging.info("__finish server")
self.com_manager.stop_receive_message()
from abc import abstractmethod
from .message import Message
from .observer import Observer
class BaseCommunicationManager(object):
@abstractmethod
def send_message(self, msg: Message):
pass
@abstractmethod
def add_observer(self, observer: Observer):
pass
@abstractmethod
def remove_observer(self, observer: Observer):
pass
@abstractmethod
def handle_receive_message(self):
pass
@abstractmethod
def stop_receive_message(self):
pass
python grpc_comm_manager.py 0 1 > ./grpc_client_log_0_1.txt 2>&1 &
python grpc_comm_manager.py 1 2 > ./grpc_client_log_1_2.txt 2>&1 &
python grpc_comm_manager.py 2 0 > ./grpc_client_log_2_0.txt 2>&1
\ No newline at end of file
import logging
import os
import threading
from concurrent import futures
from typing import List
import grpc
from ..gRPC import grpc_comm_manager_pb2_grpc, grpc_comm_manager_pb2
lock = threading.Lock()
from ...communication.base_com_manager import BaseCommunicationManager
from ...communication.message import Message
from ...communication.observer import Observer
from ...communication.gRPC.grpc_server import GRPCCOMMServicer
import csv
class GRPCCommManager(BaseCommunicationManager):
def __init__(self, host, port, ip_config_path, topic="fedml", client_id=0, client_num=0):
# host is the ip address of server
self.host = host
self.port = str(port)
self._topic = topic
self.client_id = client_id
self.client_num = client_num
self._observers: List[Observer] = []
if client_id == 0:
self.node_type = "server"
else:
self.node_type = "client"
self.opts = [
("grpc.max_send_message_length", 1000 * 1024 * 1024),
("grpc.max_receive_message_length", 1000 * 1024 * 1024),
("grpc.enable_http_proxy", 0),
]
self.grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=client_num), options=self.opts)
self.grpc_servicer = GRPCCOMMServicer(host, port, client_num, client_id)
grpc_comm_manager_pb2_grpc.add_gRPCCommManagerServicer_to_server(self.grpc_servicer, self.grpc_server)
logging.info(os.getcwd())
self.ip_config = self._build_ip_table(ip_config_path)
# starts a grpc_server on local machine using ip address "0.0.0.0"
self.grpc_server.add_insecure_port("{}:{}".format("0.0.0.0", port))
self.grpc_server.start()
self.is_running = True
print("server started. Listening on port " + str(port))
def send_message(self, msg: Message):
payload = msg.to_json()
receiver_id = msg.get_receiver_id()
PORT_BASE = 50000
# lookup ip of receiver from self.ip_config table
receiver_ip = self.ip_config[str(receiver_id)]
channel_url = "{}:{}".format(receiver_ip, str(PORT_BASE + receiver_id))
channel = grpc.insecure_channel(channel_url, options=self.opts)
stub = grpc_comm_manager_pb2_grpc.gRPCCommManagerStub(channel)
request = grpc_comm_manager_pb2.CommRequest()
logging.info("sending message to {}".format(channel_url))
request.client_id = self.client_id
request.message = payload
stub.sendMessage(request)
logging.debug("sent successfully")
channel.close()
def add_observer(self, observer: Observer):
self._observers.append(observer)
def remove_observer(self, observer: Observer):
self._observers.remove(observer)
def handle_receive_message(self):
thread = threading.Thread(target=self.message_handling_subroutine)
thread.start()
def message_handling_subroutine(self):
while self.is_running:
if self.grpc_servicer.message_q.qsize() > 0:
lock.acquire()
msg_params_string = self.grpc_servicer.message_q.get()
msg_params = Message()
msg_params.init_from_json_string(msg_params_string)
msg_type = msg_params.get_type()
for observer in self._observers:
observer.receive_message(msg_type, msg_params)
lock.release()
return
def stop_receive_message(self):
self.grpc_server.stop(None)
self.is_running = False
def notify(self, message: Message):
msg_type = message.get_type()
for observer in self._observers:
observer.receive_message(msg_type, message)
def _build_ip_table(self, path):
ip_config = dict()
with open(path, newline="") as csv_file:
csv_reader = csv.reader(csv_file)
# skip header line
next(csv_reader)
for row in csv_reader:
receiver_id, receiver_ip = row
ip_config[receiver_id] = receiver_ip
return ip_config
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: grpc_comm_manager.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='grpc_comm_manager.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x17grpc_comm_manager.proto\"1\n\x0b\x43ommRequest\x12\x11\n\tclient_id\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x43ommResponse\x12\x11\n\tclient_id\x18\x01 \x01(\x05\x12\x0f\n\x07message\x18\x02 \x01(\t2r\n\x0fgRPCCommManager\x12*\n\x0bsendMessage\x12\x0c.CommRequest\x1a\r.CommResponse\x12\x33\n\x14handleReceiveMessage\x12\x0c.CommRequest\x1a\r.CommResponseb\x06proto3'
)
_COMMREQUEST = _descriptor.Descriptor(
name='CommRequest',
full_name='CommRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='client_id', full_name='CommRequest.client_id', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='message', full_name='CommRequest.message', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=27,
serialized_end=76,
)
_COMMRESPONSE = _descriptor.Descriptor(
name='CommResponse',
full_name='CommResponse',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='client_id', full_name='CommResponse.client_id', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='message', full_name='CommResponse.message', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=78,
serialized_end=128,
)
DESCRIPTOR.message_types_by_name['CommRequest'] = _COMMREQUEST
DESCRIPTOR.message_types_by_name['CommResponse'] = _COMMRESPONSE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
CommRequest = _reflection.GeneratedProtocolMessageType('CommRequest', (_message.Message,), {
'DESCRIPTOR' : _COMMREQUEST,
'__module__' : 'grpc_comm_manager_pb2'
# @@protoc_insertion_point(class_scope:CommRequest)
})
_sym_db.RegisterMessage(CommRequest)
CommResponse = _reflection.GeneratedProtocolMessageType('CommResponse', (_message.Message,), {
'DESCRIPTOR' : _COMMRESPONSE,
'__module__' : 'grpc_comm_manager_pb2'
# @@protoc_insertion_point(class_scope:CommResponse)
})
_sym_db.RegisterMessage(CommResponse)
_GRPCCOMMMANAGER = _descriptor.ServiceDescriptor(
name='gRPCCommManager',
full_name='gRPCCommManager',
file=DESCRIPTOR,
index=0,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=130,
serialized_end=244,
methods=[
_descriptor.MethodDescriptor(
name='sendMessage',
full_name='gRPCCommManager.sendMessage',
index=0,
containing_service=None,
input_type=_COMMREQUEST,
output_type=_COMMRESPONSE,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
_descriptor.MethodDescriptor(
name='handleReceiveMessage',
full_name='gRPCCommManager.handleReceiveMessage',
index=1,
containing_service=None,
input_type=_COMMREQUEST,
output_type=_COMMRESPONSE,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
])
_sym_db.RegisterServiceDescriptor(_GRPCCOMMMANAGER)
DESCRIPTOR.services_by_name['gRPCCommManager'] = _GRPCCOMMMANAGER
# @@protoc_insertion_point(module_scope)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
from ..gRPC import grpc_comm_manager_pb2 as grpc__comm__manager__pb2
class gRPCCommManagerStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.sendMessage = channel.unary_unary(
'/gRPCCommManager/sendMessage',
request_serializer=grpc__comm__manager__pb2.CommRequest.SerializeToString,
response_deserializer=grpc__comm__manager__pb2.CommResponse.FromString,
)
self.handleReceiveMessage = channel.unary_unary(
'/gRPCCommManager/handleReceiveMessage',
request_serializer=grpc__comm__manager__pb2.CommRequest.SerializeToString,
response_deserializer=grpc__comm__manager__pb2.CommResponse.FromString,
)
class gRPCCommManagerServicer(object):
"""Missing associated documentation comment in .proto file."""
def sendMessage(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def handleReceiveMessage(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_gRPCCommManagerServicer_to_server(servicer, server):
rpc_method_handlers = {
'sendMessage': grpc.unary_unary_rpc_method_handler(
servicer.sendMessage,
request_deserializer=grpc__comm__manager__pb2.CommRequest.FromString,
response_serializer=grpc__comm__manager__pb2.CommResponse.SerializeToString,
),
'handleReceiveMessage': grpc.unary_unary_rpc_method_handler(
servicer.handleReceiveMessage,
request_deserializer=grpc__comm__manager__pb2.CommRequest.FromString,
response_serializer=grpc__comm__manager__pb2.CommResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'gRPCCommManager', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class gRPCCommManager(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def sendMessage(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/gRPCCommManager/sendMessage',
grpc__comm__manager__pb2.CommRequest.SerializeToString,
grpc__comm__manager__pb2.CommResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def handleReceiveMessage(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/gRPCCommManager/handleReceiveMessage',
grpc__comm__manager__pb2.CommRequest.SerializeToString,
grpc__comm__manager__pb2.CommResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
from ..gRPC import grpc_comm_manager_pb2, grpc_comm_manager_pb2_grpc
import queue
import threading
import logging
lock = threading.Lock()
class GRPCCOMMServicer(grpc_comm_manager_pb2_grpc.gRPCCommManagerServicer):
def __init__(self, host, port, client_num, client_id):
# host is the ip address of server
self.host = host
self.port = port
self.client_num = client_num
self.client_id = client_id
if self.client_id == 0:
self.node_type = "server"
else:
self.node_type = "client"
self.message_q = queue.Queue()
def sendMessage(self, request, context):
context_ip = context.peer().split(":")[1]
logging.info("client_{} got something from client_{} from ip address {}".format(
self.client_id,
request.client_id,
context_ip
))
response = grpc_comm_manager_pb2.CommResponse()
response.message = "message received"
lock.acquire()
self.message_q.put(request.message)
lock.release()
return response
def handleReceiveMessage(self, request, context):
pass
import csv
def build_ip_table(path):
ip_config = dict()
with open(path, newline="") as csv_file:
csv_reader = csv.reader(csv_file)
# skip header line
next(csv_reader)
for row in csv_reader:
receiver_id, receiver_ip = row
ip_config[receiver_id] = receiver_ip
return ip_config
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. grpc_comm_manager.proto
\ No newline at end of file
syntax = "proto3";
service gRPCCommManager {
rpc sendMessage (CommRequest) returns (CommResponse);
rpc handleReceiveMessage(CommRequest) returns (CommResponse);
}
message CommRequest {
int32 client_id = 1;
string message = 2;
}
message CommResponse {
int32 client_id = 1;
string message = 2;
}
import json
import sys
class Message(object):
MSG_ARG_KEY_OPERATION = "operation"
MSG_ARG_KEY_TYPE = "msg_type"
MSG_ARG_KEY_SENDER = "sender"
MSG_ARG_KEY_RECEIVER = "receiver"
MSG_OPERATION_SEND = "send"
MSG_OPERATION_RECEIVE = "receive"
MSG_OPERATION_BROADCAST = "broadcast"
MSG_OPERATION_REDUCE = "reduce"
MSG_ARG_KEY_MODEL_PARAMS = "model_params"
def __init__(self, type=0, sender_id=0, receiver_id=0):
self.type = type
self.sender_id = sender_id
self.receiver_id = receiver_id
self.msg_params = {}
self.msg_params[Message.MSG_ARG_KEY_TYPE] = type
self.msg_params[Message.MSG_ARG_KEY_SENDER] = sender_id
self.msg_params[Message.MSG_ARG_KEY_RECEIVER] = receiver_id
def init(self, msg_params):
self.msg_params = msg_params
def init_from_json_string(self, json_string):
self.msg_params = json.loads(json_string)
self.type = self.msg_params[Message.MSG_ARG_KEY_TYPE]
self.sender_id = self.msg_params[Message.MSG_ARG_KEY_SENDER]
self.receiver_id = self.msg_params[Message.MSG_ARG_KEY_RECEIVER]
# print("msg_params = " + str(self.msg_params))
def get_sender_id(self):
return self.sender_id
def get_receiver_id(self):
return self.receiver_id
def add_params(self, key, value):
self.msg_params[key] = value
def get_params(self):
return self.msg_params
def add(self, key, value):
self.msg_params[key] = value
def get(self, key):
return self.msg_params[key]
def get_type(self):
return self.msg_params[Message.MSG_ARG_KEY_TYPE]
def to_string(self):
return self.msg_params
def to_json(self):
json_string = json.dumps(self.msg_params)
print("json string size = " + str(sys.getsizeof(json_string)))
return json_string
def get_content(self):
print_dict = self.msg_params.copy()
msg_str = str(self.__to_msg_type_string()) + ": " + str(print_dict)
return msg_str
def __to_msg_type_string(self):
type = self.msg_params[Message.MSG_ARG_KEY_TYPE]
return type
from abc import ABC, abstractmethod
class Observer(ABC):
@abstractmethod
def receive_message(self, msg_type, msg_params) -> None:
pass
import logging
from abc import abstractmethod
import sys
from ..communication.gRPC.grpc_comm_manager import GRPCCommManager
from ..communication.observer import Observer
class ServerManager(Observer):
def __init__(self, args, comm=None, rank=0, size=0, backend="MPI"):
self.args = args
self.size = size
self.rank = rank
self.backend = backend
HOST = "0.0.0.0"
PORT = 50000 + rank
self.com_manager = GRPCCommManager(
HOST, PORT, ip_config_path=args.grpc_ipconfig_path, client_id=rank, client_num=size - 1
)
self.com_manager.add_observer(self)
self.message_handler_dict = dict()
def run(self):
self.register_message_receive_handlers()
self.com_manager.handle_receive_message()
def get_sender_id(self):
return self.rank
def receive_message(self, msg_type, msg_params) -> None:
# logging.info("receive_message. rank_id = %d, msg_type = %s. msg_params = %s" % (
# self.rank, str(msg_type), str(msg_params.get_content())))
handler_callback_func = self.message_handler_dict[msg_type]
handler_callback_func(msg_params)
def send_message(self, message):
self.com_manager.send_message(message)
@abstractmethod
def register_message_receive_handlers(self) -> None:
pass
def register_message_receive_handler(self, msg_type, handler_callback_func):
self.message_handler_dict[msg_type] = handler_callback_func
def finish(self):
logging.info("__finish server")
self.com_manager.stop_receive_message()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment