Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
spark
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
cs525-sp18-g07
spark
Commits
95915f8b
Commit
95915f8b
authored
11 years ago
by
Tor Myklebust
Browse files
Options
Downloads
Patches
Plain Diff
First cut at python mllib bindings. Only LinearRegression is supported.
parent
d3b1af4b
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
+51
-0
51 additions, 0 deletions
...ain/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
python/pyspark/mllib.py
+114
-0
114 additions, 0 deletions
python/pyspark/mllib.py
with
165 additions
and
0 deletions
mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
0 → 100644
+
51
−
0
View file @
95915f8b
import
org.apache.spark.api.java.JavaRDD
import
org.apache.spark.mllib.regression._
import
java.nio.ByteBuffer
import
java.nio.ByteOrder
import
java.nio.DoubleBuffer
class
PythonMLLibAPI
extends
Serializable
{
def
deserializeDoubleVector
(
bytes
:
Array
[
Byte
])
:
Array
[
Double
]
=
{
val
packetLength
=
bytes
.
length
;
if
(
packetLength
<
16
)
{
throw
new
IllegalArgumentException
(
"Byte array too short."
);
}
val
bb
=
ByteBuffer
.
wrap
(
bytes
);
bb
.
order
(
ByteOrder
.
nativeOrder
());
val
magic
=
bb
.
getLong
();
if
(
magic
!=
1
)
{
throw
new
IllegalArgumentException
(
"Magic "
+
magic
+
" is wrong."
);
}
val
length
=
bb
.
getLong
();
if
(
packetLength
!=
16
+
8
*
length
)
{
throw
new
IllegalArgumentException
(
"Length "
+
length
+
"is wrong."
);
}
val
db
=
bb
.
asDoubleBuffer
();
val
ans
=
new
Array
[
Double
](
length
.
toInt
);
db
.
get
(
ans
);
return
ans
;
}
def
serializeDoubleVector
(
doubles
:
Array
[
Double
])
:
Array
[
Byte
]
=
{
val
len
=
doubles
.
length
;
val
bytes
=
new
Array
[
Byte
](
16
+
8
*
len
);
val
bb
=
ByteBuffer
.
wrap
(
bytes
);
bb
.
order
(
ByteOrder
.
nativeOrder
());
bb
.
putLong
(
1
);
bb
.
putLong
(
len
);
val
db
=
bb
.
asDoubleBuffer
();
db
.
put
(
doubles
);
return
bytes
;
}
def
trainLinearRegressionModel
(
dataBytesJRDD
:
JavaRDD
[
Array
[
Byte
]])
:
java.util.List
[
java.lang.Object
]
=
{
val
data
=
dataBytesJRDD
.
rdd
.
map
(
x
=>
deserializeDoubleVector
(
x
))
.
map
(
v
=>
LabeledPoint
(
v
(
0
),
v
.
slice
(
1
,
v
.
length
)));
val
model
=
LinearRegressionWithSGD
.
train
(
data
,
222
);
val
ret
=
new
java
.
util
.
LinkedList
[
java.lang.Object
]();
ret
.
add
(
serializeDoubleVector
(
model
.
weights
));
ret
.
add
(
model
.
intercept
:
java.lang.Double
);
return
ret
;
}
}
This diff is collapsed.
Click to expand it.
python/pyspark/mllib.py
0 → 100644
+
114
−
0
View file @
95915f8b
from
numpy
import
*
;
from
pyspark.serializers
import
NoOpSerializer
,
FramedSerializer
,
\
BatchedSerializer
,
CloudPickleSerializer
,
pack_long
#__all__ = ["train_linear_regression_model"];
# Double vector format:
#
# [8-byte 1] [8-byte length] [length*8 bytes of data]
#
# Double matrix format:
#
# [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data]
#
# This is all in machine-endian. That means that the Java interpreter and the
# Python interpreter must agree on what endian the machine is.
def
deserialize_byte_array
(
shape
,
ba
,
offset
):
"""
Implementation detail. Do not use directly.
"""
ar
=
ndarray
(
shape
=
shape
,
buffer
=
ba
,
offset
=
offset
,
dtype
=
"
float64
"
,
\
order
=
'
C
'
);
return
ar
.
copy
();
def
serialize_double_vector
(
v
):
"""
Implementation detail. Do not use directly.
"""
if
(
type
(
v
)
==
ndarray
and
v
.
dtype
==
float64
and
v
.
ndim
==
1
):
length
=
v
.
shape
[
0
];
ba
=
bytearray
(
16
+
8
*
length
);
header
=
ndarray
(
shape
=
[
2
],
buffer
=
ba
,
dtype
=
"
int64
"
);
header
[
0
]
=
1
;
header
[
1
]
=
length
;
copyto
(
ndarray
(
shape
=
[
length
],
buffer
=
ba
,
offset
=
16
,
dtype
=
"
float64
"
),
v
);
return
ba
;
else
:
raise
TypeError
(
"
serialize_double_vector called on a non-double-vector
"
);
def
deserialize_double_vector
(
ba
):
"""
Implementation detail. Do not use directly.
"""
if
(
type
(
ba
)
==
bytearray
and
len
(
ba
)
>=
16
and
(
len
(
ba
)
&
7
==
0
)):
header
=
ndarray
(
shape
=
[
2
],
buffer
=
ba
,
dtype
=
"
int64
"
);
if
(
header
[
0
]
!=
1
):
raise
TypeError
(
"
deserialize_double_vector called on bytearray with
"
\
"
wrong magic
"
);
length
=
header
[
1
];
if
(
len
(
ba
)
!=
8
*
length
+
16
):
raise
TypeError
(
"
deserialize_double_vector called on bytearray with
"
\
"
wrong length
"
);
return
deserialize_byte_array
([
length
],
ba
,
16
);
else
:
raise
TypeError
(
"
deserialize_double_vector called on a non-bytearray
"
);
def
serialize_double_matrix
(
m
):
"""
Implementation detail. Do not use directly.
"""
if
(
type
(
m
)
==
ndarray
and
m
.
dtype
==
float64
and
m
.
ndim
==
2
):
rows
=
m
.
shape
[
0
];
cols
=
m
.
shape
[
1
];
ba
=
bytearray
(
24
+
8
*
rows
*
cols
);
header
=
ndarray
(
shape
=
[
3
],
buffer
=
ba
,
dtype
=
"
int64
"
);
header
[
0
]
=
2
;
header
[
1
]
=
rows
;
header
[
2
]
=
cols
;
copyto
(
ndarray
(
shape
=
[
rows
,
cols
],
buffer
=
ba
,
offset
=
24
,
dtype
=
"
float64
"
,
\
order
=
'
C
'
),
m
);
return
ba
;
else
:
print
type
(
m
);
print
m
.
dtype
;
print
m
.
ndim
;
raise
TypeError
(
"
serialize_double_matrix called on a non-double-matrix
"
);
def
deserialize_double_matrix
(
ba
):
"""
Implementation detail. Do not use directly.
"""
if
(
type
(
ba
)
==
bytearray
and
len
(
ba
)
>=
24
and
(
len
(
ba
)
&
7
==
0
)):
header
=
ndarray
(
shape
=
[
3
],
buffer
=
ba
,
dtype
=
"
int64
"
);
if
(
header
[
0
]
!=
2
):
raise
TypeError
(
"
deserialize_double_matrix called on bytearray with
"
\
"
wrong magic
"
);
rows
=
header
[
1
];
cols
=
header
[
2
];
if
(
len
(
ba
)
!=
8
*
rows
*
cols
+
24
):
raise
TypeError
(
"
deserialize_double_matrix called on bytearray with
"
\
"
wrong length
"
);
return
deserialize_byte_array
([
rows
,
cols
],
ba
,
24
);
else
:
raise
TypeError
(
"
deserialize_double_matrix called on a non-bytearray
"
);
class
LinearRegressionModel
:
_coeff
=
None
;
_intercept
=
None
;
def
__init__
(
self
,
coeff
,
intercept
):
self
.
_coeff
=
coeff
;
self
.
_intercept
=
intercept
;
def
predict
(
self
,
x
):
if
(
type
(
x
)
==
ndarray
):
if
(
x
.
ndim
==
1
):
return
dot
(
_coeff
,
x
)
-
_intercept
;
else
:
raise
RuntimeError
(
"
Bulk predict not yet supported.
"
);
elif
(
type
(
x
)
==
RDD
):
raise
RuntimeError
(
"
Bulk predict not yet supported.
"
);
else
:
raise
TypeError
(
"
Bad type argument to LinearRegressionModel::predict
"
);
def
train_linear_regression_model
(
sc
,
data
):
"""
Train a linear regression model on the given data.
"""
dataBytes
=
data
.
map
(
serialize_double_vector
);
sc
.
serializer
=
NoOpSerializer
();
dataBytes
.
cache
();
api
=
sc
.
_jvm
.
PythonMLLibAPI
();
ans
=
api
.
trainLinearRegressionModel
(
dataBytes
.
_jrdd
);
if
(
len
(
ans
)
!=
2
or
type
(
ans
[
0
])
!=
bytearray
or
type
(
ans
[
1
])
!=
float
):
raise
RuntimeError
(
"
train_linear_regression_model received garbage
"
\
"
from JVM
"
);
return
LinearRegressionModel
(
deserialize_double_vector
(
ans
[
0
]),
ans
[
1
]);
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment