From b98572c70ad3932381a55f23f82600d7e435d2eb Mon Sep 17 00:00:00 2001
From: Jey Kottalam <jey@cs.berkeley.edu>
Date: Wed, 3 Jul 2013 16:57:22 -0700
Subject: [PATCH] Generate new SSH key for the cluster, make "--identity-file"
 optional

---
 ec2/spark_ec2.py | 58 ++++++++++++++++++++++++++++++------------------
 1 file changed, 37 insertions(+), 21 deletions(-)

diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 75dd0ffa61..0858b126c5 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -104,11 +104,7 @@ def parse_args():
     parser.print_help()
     sys.exit(1)
   (action, cluster_name) = args
-  if opts.identity_file == None and action in ['launch', 'login', 'start']:
-    print >> stderr, ("ERROR: The -i or --identity-file argument is " +
-                      "required for " + action)
-    sys.exit(1)
-  
+
   # Boto config check
   # http://boto.cloudhackers.com/en/latest/boto_config_tut.html
   home_dir = os.getenv('HOME')
@@ -392,10 +388,18 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
 def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
   master = master_nodes[0].public_dns_name
   if deploy_ssh_key:
-    print "Copying SSH key %s to master..." % opts.identity_file
-    ssh(master, opts, 'mkdir -p ~/.ssh')
-    scp(master, opts, opts.identity_file, '~/.ssh/id_rsa')
-    ssh(master, opts, 'chmod 600 ~/.ssh/id_rsa')
+    print "Generating cluster's SSH key on master..."
+    key_setup = """
+      [ -f ~/.ssh/id_rsa ] ||
+        (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa &&
+         cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys)
+    """
+    ssh(master, opts, key_setup)
+    dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh'])
+    print "Transferring cluster's SSH key to slaves..."
+    for slave in slave_nodes:
+      print slave.public_dns_name
+      ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar)
 
   modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs', 
              'mapreduce', 'spark-standalone']
@@ -556,7 +560,9 @@ def stringify_command(parts):
 
 
 def ssh_args(opts):
-  parts = ['-o', 'StrictHostKeyChecking=no', '-i', opts.identity_file]
+  parts = ['-o', 'StrictHostKeyChecking=no']
+  if opts.identity_file is not None:
+    parts += ['-i', opts.identity_file]
   return parts
 
 
@@ -564,16 +570,6 @@ def ssh_command(opts):
   return ['ssh'] + ssh_args(opts)
 
 
-def scp_command(opts):
-  return ['scp', '-q'] + ssh_args(opts)
-
-
-# Copy a file to a given host through scp, throwing an exception if scp fails
-def scp(host, opts, local_file, dest_file):
-  subprocess.check_call(
-      scp_command(opts) + [local_file, "%s@%s:%s" % (opts.user, host, dest_file)])
-
-
 # Run a command on a host through ssh, retrying up to two times
 # and then throwing an exception if ssh continues to fail.
 def ssh(host, opts, command):
@@ -585,13 +581,33 @@ def ssh(host, opts, command):
     except subprocess.CalledProcessError as e:
       if (tries > 2):
         raise e
-      print "Couldn't connect to host {0}, waiting 30 seconds".format(e)
+      print "Error connecting to host, sleeping 30: {0}".format(e)
       time.sleep(30)
       tries = tries + 1
 
 
+def ssh_read(host, opts, command):
+  return subprocess.check_output(
+      ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)])
 
 
+def ssh_write(host, opts, command, input):
+  tries = 0
+  while True:
+    proc = subprocess.Popen(
+        ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)],
+        stdin=subprocess.PIPE)
+    proc.stdin.write(input)
+    proc.stdin.close()
+    if proc.wait() == 0:
+      break
+    elif (tries > 2):
+      raise RuntimeError("ssh_write error %s" % proc.returncode)
+    else:
+      print "Error connecting to host, sleeping 30"
+      time.sleep(30)
+      tries = tries + 1
+
 
 # Gets a list of zones to launch instances in
 def get_zones(conn, opts):
-- 
GitLab