Skip to content
Snippets Groups Projects
Commit 02208a17 authored by Tor Myklebust's avatar Tor Myklebust
Browse files

Initial weights in Scala are ones; do that too. Also fix some errors.

parent 4e821390
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
# #
from numpy import ndarray, copyto, float64, int64, int32, zeros, array_equal, array, dot, shape from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
from pyspark import SparkContext from pyspark import SparkContext
# Double vector format: # Double vector format:
...@@ -143,7 +143,7 @@ def _linear_predictor_typecheck(x, coeffs): ...@@ -143,7 +143,7 @@ def _linear_predictor_typecheck(x, coeffs):
elif (type(x) == RDD): elif (type(x) == RDD):
raise RuntimeError("Bulk predict not yet supported.") raise RuntimeError("Bulk predict not yet supported.")
else: else:
raise TypeError("Argument of type " + type(x) + " unsupported") raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
def _get_unmangled_rdd(data, serializer): def _get_unmangled_rdd(data, serializer):
dataBytes = data.map(serializer) dataBytes = data.map(serializer)
...@@ -182,11 +182,11 @@ def _get_initial_weights(initial_weights, data): ...@@ -182,11 +182,11 @@ def _get_initial_weights(initial_weights, data):
initial_weights = data.first() initial_weights = data.first()
if type(initial_weights) != ndarray: if type(initial_weights) != ndarray:
raise TypeError("At least one data element has type " raise TypeError("At least one data element has type "
+ type(initial_weights) + " which is not ndarray") + type(initial_weights).__name__ + " which is not ndarray")
if initial_weights.ndim != 1: if initial_weights.ndim != 1:
raise TypeError("At least one data element has " raise TypeError("At least one data element has "
+ initial_weights.ndim + " dimensions, which is not 1") + initial_weights.ndim + " dimensions, which is not 1")
initial_weights = zeros([initial_weights.shape[0] - 1]) initial_weights = ones([initial_weights.shape[0] - 1])
return initial_weights return initial_weights
# train_func should take two parameters, namely data and initial_weights, and # train_func should take two parameters, namely data and initial_weights, and
...@@ -200,10 +200,10 @@ def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): ...@@ -200,10 +200,10 @@ def _regression_train_wrapper(sc, train_func, klass, data, initial_weights):
raise RuntimeError("JVM call result had unexpected length") raise RuntimeError("JVM call result had unexpected length")
elif type(ans[0]) != bytearray: elif type(ans[0]) != bytearray:
raise RuntimeError("JVM call result had first element of type " raise RuntimeError("JVM call result had first element of type "
+ type(ans[0]) + " which is not bytearray") + type(ans[0]).__name__ + " which is not bytearray")
elif type(ans[1]) != float: elif type(ans[1]) != float:
raise RuntimeError("JVM call result had second element of type " raise RuntimeError("JVM call result had second element of type "
+ type(ans[0]) + " which is not float") + type(ans[0]).__name__ + " which is not float")
return klass(_deserialize_double_vector(ans[0]), ans[1]) return klass(_deserialize_double_vector(ans[0]), ans[1])
def _serialize_rating(r): def _serialize_rating(r):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment