diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 2ca4d8020d6bf6aea822628080a23252204ab299..a3138d6ef7813dd93756411de6caa4b69a892d6d 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -61,7 +61,8 @@ def parse_args():
   parser.add_option("-r", "--region", default="us-east-1",
       help="EC2 region zone to launch instances in")
   parser.add_option("-z", "--zone", default="",
-      help="Availability zone to launch instances in")
+      help="Availability zone to launch instances in, or 'all' to spread " +
+           "slaves across multiple")
   parser.add_option("-a", "--ami", default="latest",
       help="Amazon Machine Image ID to use, or 'latest' to use latest " +
            "available AMI (default: latest)")
@@ -217,17 +218,25 @@ def launch_cluster(conn, opts, cluster_name):
     # Launch spot instances with the requested price
     print ("Requesting %d slaves as spot instances with price $%.3f" %
            (opts.slaves, opts.spot_price))
-    slave_reqs = conn.request_spot_instances(
-        price = opts.spot_price,
-        image_id = opts.ami,
-        launch_group = "launch-group-%s" % cluster_name,
-        placement = opts.zone,
-        count = opts.slaves,
-        key_name = opts.key_pair,
-        security_groups = [slave_group],
-        instance_type = opts.instance_type,
-        block_device_map = block_map)
-    my_req_ids = [req.id for req in slave_reqs]
+    zones = get_zones(conn, opts)
+    num_zones = len(zones)
+    i = 0
+    my_req_ids = []
+    for zone in zones:
+      num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
+      slave_reqs = conn.request_spot_instances(
+          price = opts.spot_price,
+          image_id = opts.ami,
+          launch_group = "launch-group-%s" % cluster_name,
+          placement = zone,
+          count = num_slaves_this_zone,
+          key_name = opts.key_pair,
+          security_groups = [slave_group],
+          instance_type = opts.instance_type,
+          block_device_map = block_map)
+      my_req_ids += [req.id for req in slave_reqs]
+      i += 1
+    
     print "Waiting for spot instances to be granted..."
     try:
       while True:
@@ -262,20 +271,30 @@ def launch_cluster(conn, opts, cluster_name):
       sys.exit(0)
   else:
     # Launch non-spot instances
-    slave_res = image.run(key_name = opts.key_pair,
-                          security_groups = [slave_group],
-                          instance_type = opts.instance_type,
-                          placement = opts.zone,
-                          min_count = opts.slaves,
-                          max_count = opts.slaves,
-                          block_device_map = block_map)
-    slave_nodes = slave_res.instances
-    print "Launched slaves, regid = " + slave_res.id
+    zones = get_zones(conn, opts)
+    num_zones = len(zones)
+    i = 0
+    slave_nodes = []
+    for zone in zones:
+      num_slaves_this_zone = get_partition(opts.slaves, num_zones, i)
+      slave_res = image.run(key_name = opts.key_pair,
+                            security_groups = [slave_group],
+                            instance_type = opts.instance_type,
+                            placement = zone,
+                            min_count = num_slaves_this_zone,
+                            max_count = num_slaves_this_zone,
+                            block_device_map = block_map)
+      slave_nodes += slave_res.instances
+      print "Launched %d slaves in %s, regid = %s" % (num_slaves_this_zone,
+                                                      zone, slave_res.id)
+      i += 1
 
   # Launch masters
   master_type = opts.master_instance_type
   if master_type == "":
     master_type = opts.instance_type
+  if opts.zone == 'all':
+    opts.zone = random.choice(conn.get_all_zones()).name
   master_res = image.run(key_name = opts.key_pair,
                          security_groups = [master_group],
                          instance_type = master_type,
@@ -284,7 +303,7 @@ def launch_cluster(conn, opts, cluster_name):
                          max_count = 1,
                          block_device_map = block_map)
   master_nodes = master_res.instances
-  print "Launched master, regid = " + master_res.id
+  print "Launched master in %s, regid = %s" % (zone, master_res.id)
 
   zoo_nodes = []
 
@@ -474,6 +493,23 @@ def ssh(host, opts, command):
       (opts.identity_file, opts.user, host, command), shell=True)
 
 
+# Gets a list of zones to launch instances in
+def get_zones(conn, opts):
+  if opts.zone == 'all':
+    zones = [z.name for z in conn.get_all_zones()]
+  else:
+    zones = [opts.zone]
+  return zones
+
+
+# Gets the number of items in a partition
+def get_partition(total, num_partitions, current_partitions):
+  num_slaves_this_zone = total / num_partitions
+  if (total % num_partitions) - current_partitions > 0:
+    num_slaves_this_zone += 1
+  return num_slaves_this_zone
+
+
 def main():
   (opts, action, cluster_name) = parse_args()
   conn = boto.ec2.connect_to_region(opts.region)