Note
Click here to download the full example code
Convert a Spark model¶
This example trains a Spark model on the Iris dataset and converts it into ONNX.
Train a model¶
import os
import numpy
from pandas import DataFrame
import onnx
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import onnxruntime as rt
import skl2onnx
import pyspark
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression, LinearSVC
from pyspark.ml.linalg import VectorUDT, SparseVector
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler, StringIndexer
import onnxmltools
from onnxconverter_common.data_types import FloatTensorType
from onnxmltools.convert import convert_sparkml
def start_spark(options=None):
    """Create (or reuse) a local, single-worker SparkSession.

    Parameters
    ----------
    options : dict, optional
        Spark configuration key/value pairs applied to the session builder.

    Returns
    -------
    SparkSession
        A session running against the ``local[1]`` master.
    """
    import sys

    # Point Spark at the installed pyspark package and force the workers to
    # use the same Python interpreter as the driver; `os` and `pyspark` are
    # already imported at module level, so no local re-import is needed.
    os.environ["SPARK_HOME"] = pyspark.__path__[0]
    os.environ["PYSPARK_PYTHON"] = sys.executable
    os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

    builder = SparkSession.builder.appName("pyspark-unittesting").master("local[1]")
    if options:
        for key, value in options.items():
            builder.config(key, value)
    return builder.getOrCreate()
def stop_spark(spark):
    """Shut down the SparkContext backing *spark*."""
    context = spark.sparkContext
    context.stop()
# Load Iris, hold out a test split, and put the training rows in a pandas
# DataFrame so Spark can ingest them.
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
df = DataFrame(X_train, columns="x1 x2 x3 x4".split())
df['class'] = y_train
# df.to_csv("data_train.csv", index=False, header=False)
this_script_dir = os.path.abspath('.')
# On Windows, Spark needs HADOOP_HOME (for winutils.exe); point it at the
# current directory if the user has not set it.
if os.name == 'nt' and os.environ.get('HADOOP_HOME') is None:
    print('setting HADOOP_HOME to: ', this_script_dir)
    os.environ['HADOOP_HOME'] = this_script_dir
spark_session = start_spark()
# NOTE(review): input_path is never used below — the data comes from the
# in-memory DataFrame, not from data_train.csv.
input_path = os.path.join(this_script_dir, "data_train.csv")
data = spark_session.createDataFrame(df)
# Pack the four feature columns into a single 'features' vector column,
# as Spark ML estimators expect.
feature_cols = data.columns[:-1]
assembler = VectorAssembler(inputCols=feature_cols, outputCol='features')
train_data = assembler.transform(data)
train_data = train_data.select(['features', 'class'])
# Re-encode the class column as a 0-based 'label' index column.
label_indexer = StringIndexer(inputCol='class', outputCol='label').fit(train_data)
train_data = label_indexer.transform(train_data)
train_data = train_data.select(['features', 'label'])
train_data.show(10)
# Fit a Spark ML logistic regression; `model` is converted to ONNX below.
lr = LogisticRegression(maxIter=100, tol=0.0001)
model = lr.fit(train_data)
Out:
+-----------------+-----+
| features|label|
+-----------------+-----+
|[4.9,3.1,1.5,0.1]| 1.0|
|[5.9,3.0,5.1,1.8]| 2.0|
|[5.9,3.0,4.2,1.5]| 0.0|
|[5.4,3.7,1.5,0.2]| 1.0|
|[6.4,2.8,5.6,2.1]| 2.0|
|[5.6,3.0,4.5,1.5]| 0.0|
|[4.7,3.2,1.6,0.2]| 1.0|
|[7.7,2.6,6.9,2.3]| 2.0|
|[5.0,3.5,1.6,0.6]| 1.0|
|[5.4,3.9,1.3,0.4]| 1.0|
+-----------------+-----+
only showing top 10 rows
Convert a model into ONNX¶
from onnx.defs import onnx_opset_version

initial_types = [('features', FloatTensorType([None, 4]))]
# Cap the target opset at what both the installed onnx package and the
# sparkml converter support. Without an explicit target_opset,
# convert_sparkml may default to an opset newer than the environment and
# raise "target_opset ... is higher than the number of the installed onnx
# package or the converter support".
# NOTE(review): 14 is the converter ceiling reported by this build — raise
# it when onnxmltools advertises a newer supported opset.
target_opset = min(14, onnx_opset_version())
onx = convert_sparkml(
    model, 'sparkml logistic regression', initial_types,
    target_opset=target_opset)
stop_spark(spark_session)
Traceback (most recent call last):
File "somewhereonnxmltools-jenkins_39_std/onnxmltools/docs/examples/plot_convert_sparkml.py", line 98, in <module>
onx = convert_sparkml(model, 'sparkml logistic regression', initial_types)
File "somewhereonnxmltools-jenkins_39_std/onnxmltools/onnxmltools/convert/main.py", line 166, in convert_sparkml
return convert(model, name, initial_types, doc_string, target_opset, targeted_onnx,
File "somewhereonnxmltools-jenkins_39_std/onnxmltools/onnxmltools/convert/sparkml/convert.py", line 71, in convert
onnx_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)
File "somewhereonnxmltools-jenkins_39_std/_venv/lib/python3.9/site-packages/onnxconverter_common/topology.py", line 704, in convert_topology
raise RuntimeError(("target_opset %d is higher than the number of the installed onnx package"
RuntimeError: target_opset 15 is higher than the number of the installed onnx package or the converter support (14).
Compute the predictions with onnxruntime¶
# Score the test set with onnxruntime on the converted model.
sess = rt.InferenceSession(onx.SerializeToString())
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
feed = {input_name: X_test.astype(numpy.float32)}
pred_onx = sess.run([label_name], feed)[0]
print(pred_onx)
Display the ONNX graph¶
Finally, let’s see the graph converted with onnxmltools.
import os
import matplotlib.pyplot as plt
from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer

# Render the ONNX graph to a PNG via graphviz, then display it inline.
node_producer = GetOpNodeProducer(
    "docstring", color="yellow", fillcolor="yellow", style="filled")
pydot_graph = GetPydotGraph(
    onx.graph, name=onx.graph.name, rankdir="TB",
    node_producer=node_producer)
pydot_graph.write_dot("model.dot")
os.system('dot -O -Gdpi=300 -Tpng model.dot')

fig, ax = plt.subplots(figsize=(40, 20))
image = plt.imread("model.dot.png")
ax.imshow(image)
ax.axis('off')
Versions used for this example
# Report the versions used for this example (labels kept byte-identical
# to the original output, including their trailing spaces).
for label, module in [
    ("numpy:", numpy),
    ("scikit-learn:", sklearn),
    ("onnx: ", onnx),
    ("onnxruntime: ", rt),
    ("onnxmltools: ", onnxmltools),
    ("pyspark: ", pyspark),
]:
    print(label, module.__version__)
Total running time of the script: ( 0 minutes 37.214 seconds)