.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gyexamples/plot_parallel_execution_big_model.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_gyexamples_plot_parallel_execution_big_model.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gyexamples_plot_parallel_execution_big_model.py:

.. _l-plot-parallel-execution-big-models:

==============================================
Multithreading with onnxruntime and big models
==============================================

.. index:: thread, parallel, onnxruntime, gpu, big models

Example :ref:`l-plot-parallel-execution` shows that parallelizing the
inference over multiple GPUs on the same machine is worth doing.
However, this may not be possible when the model is too big to fit in
the memory of a single GPU. In that case, we need to split the model
and let each GPU run one piece of it.
The strategy implemented in this example consists of dividing the model
layers into consecutive pieces and pushing each piece onto a separate
GPU. Let's assume a random network has two layers L1 and L2 of roughly
the same size: GPU 1 hosts L1 and GPU 2 hosts L2. A batch contains
two images. Their inference can be decomposed the following way:

* :math:`t_1`: image 1 is copied to GPU 1
* :math:`t_2`: L1 is processed
* :math:`t_3`: output of L1 is copied to GPU 2, image 2 is copied to GPU 1
* :math:`t_4`: L1 and L2 are processed
* :math:`t_5`: output of L1 is copied to GPU 2, output of L2 is copied to CPU
* :math:`t_6`: L2 is processed
* :math:`t_7`: output of L2 is copied to CPU

This works if the copy across GPUs does not take too much time.
The improvement should be even better for a bigger batch.
This example uses the same models as in
:ref:`l-plot-parallel-execution`. A plain Python sketch of this
pipelined schedule follows.

.. contents::
    :local:
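
The following block is only an illustration and is not part of the
original script: two worker threads stand in for GPU 1 and GPU 2, two
dummy numpy functions stand in for L1 and L2, and queues play the role
of the device-to-device copies.

.. code-block:: python

    import queue
    import threading

    import numpy


    def worker(stage, q_in, q_out):
        # Each worker plays the role of one GPU running one piece of the model.
        while True:
            item = q_in.get()
            if item is None:  # sentinel: no more images
                q_out.put(None)
                break
            pos, x = item
            # "copy" the result to the next device by pushing it downstream
            q_out.put((pos, stage(x)))


    q1, q2, q_out = queue.Queue(), queue.Queue(), queue.Queue()
    threads = [
        threading.Thread(target=worker, args=(lambda x: x @ x.T, q1, q2)),
        threading.Thread(target=worker, args=(lambda x: x.sum(axis=1), q2, q_out)),
    ]
    for t in threads:
        t.start()
    for pos in range(2):  # a batch of two images
        q1.put((pos, numpy.random.rand(4, 4).astype(numpy.float32)))
    q1.put(None)

    results = []
    while (item := q_out.get()) is not None:
        results.append(item)
    print([p for p, _ in sorted(results, key=lambda r: r[0])])

While image 2 goes through the first stage, image 1 can already go
through the second one, which is exactly the overlap described above.
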
A model
=======

Let's retrieve a not-so-big model. Models are taken from the
`ONNX Model Zoo <https://github.com/onnx/models>`_ or can even be custom.

.. GENERATED FROM PYTHON SOURCE LINES 42-103

.. code-block:: default

    import gc
    import os
    import pickle
    import urllib.request
    import threading
    import time
    import sys
    import tqdm
    import numpy
    from numpy.testing import assert_allclose
    import pandas
    import onnx
    import matplotlib.pyplot as plt
    import torch.cuda
    from onnxruntime import InferenceSession, get_all_providers
    from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument
    from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
        OrtValue as C_OrtValue)
    from onnxcustom.utils.onnx_split import split_onnx
    from onnxcustom.utils.onnxruntime_helper import get_ort_device_from_session


    def download_file(url, name, min_size):
        if not os.path.exists(name):
            print(f"download '{url}'")
            with urllib.request.urlopen(url) as u:
                content = u.read()
            if len(content) < min_size:
                raise RuntimeError(
                    f"Unable to download '{url}' due to\n{content}")
            print(f"downloaded {len(content)} bytes.")
            with open(name, "wb") as f:
                f.write(content)
        else:
            print(f"'{name}' already downloaded")


    small = "custom" if "custom" in sys.argv else "big" not in sys.argv
    if small == "custom":
        model_name = "gpt2.onnx"
        url_name = None
        maxN, stepN, repN = 10, 1, 4
        big_model = True
    elif small:
        model_name = "mobilenetv2-10.onnx"
        url_name = ("https://github.com/onnx/models/raw/main/vision/"
                    "classification/mobilenet/model")
        maxN, stepN, repN = 81, 2, 4
        big_model = False
    else:
        model_name = "resnet18-v1-7.onnx"
        url_name = ("https://github.com/onnx/models/raw/main/vision/"
                    "classification/resnet/model")
        maxN, stepN, repN = 81, 2, 4
        big_model = False

    if url_name is not None:
        url_name += "/" + model_name
        download_file(url_name, model_name, 100000)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    'mobilenetv2-10.onnx' already downloaded

.. GENERATED FROM PYTHON SOURCE LINES 104-108

GPU
===

Let's first check whether CUDA is available.

.. GENERATED FROM PYTHON SOURCE LINES 108-121

.. code-block:: default

    has_cuda = "CUDAExecutionProvider" in get_all_providers()
    if not has_cuda:
        print(f"No CUDA provider was detected in {get_all_providers()}.")

    n_gpus = torch.cuda.device_count() if has_cuda else 0
    if n_gpus == 0:
        print("No GPU or one GPU was detected.")
    elif n_gpus == 1:
        print("1 GPU was detected.")
    else:
        print(f"{n_gpus} GPUs were detected.")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    No GPU or one GPU was detected.

.. GENERATED FROM PYTHON SOURCE LINES 122-130

Split the model
===============

An ONNX model is a graph; there is no notion of layers.
The function :func:`split_onnx <onnxcustom.utils.onnx_split.split_onnx>`
first detects possible cutting points, nodes whose removal breaks the
connectivity of the graph. It then picks the cutting points that split
the model into pieces of roughly the same size.

.. GENERATED FROM PYTHON SOURCE LINES 130-145

.. code-block:: default

    with open(model_name, "rb") as f:
        model = onnx.load(f)

    if model_name == "resnet18-v1-7.onnx":
        # best cutting point to parallelize on 2 GPUs
        cut_points = ["resnetv15_stage3_activation0"]
        n_parts = None
    else:
        cut_points = None
        n_parts = max(n_gpus, 2)

    pieces = split_onnx(model, n_parts=n_parts,
                        cut_points=cut_points, verbose=2)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [split_onnx] prepare splitting 105 nodes in 2 parts.
    [OnnxSplitting] look for cutting points in 105 nodes.
    [OnnxSplitting] found pos=45, size_1=7195944, size_2=6768415=0.48, split='624'
    [split_onnx] splits: [0, 45, 49], names=['624']
    [OnnxSplitting] part 1: #nodes=95/105, size=7195944/13964359=0.52
    [OnnxSplitting] part 2: #nodes=10/105, size=6768415/13964359=0.48
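
As a quick sanity check, not in the original script, one can verify
that the pieces chain together. It assumes each piece consumes exactly
the tensors the previous piece produces, in the same order:

.. code-block:: python

    # The output names of piece i should be the input names of piece i + 1.
    for i in range(len(pieces) - 1):
        out_names = [o.name for o in pieces[i].graph.output]
        in_names = [x.name for x in pieces[i + 1].graph.input]
        assert out_names == in_names, (out_names, in_names)
        print(f"piece {i} -> piece {i + 1} through {out_names}")
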
.. GENERATED FROM PYTHON SOURCE LINES 146-148

The pieces are roughly of the same size. Let's save them on disk.

.. GENERATED FROM PYTHON SOURCE LINES 148-156

.. code-block:: default

    piece_names = []
    for i, piece in enumerate(pieces):
        name = f"piece-{os.path.splitext(model_name)[0]}-{i}.onnx"
        piece_names.append(name)
        with open(name, "wb") as f:
            f.write(piece.SerializeToString())


.. GENERATED FROM PYTHON SOURCE LINES 157-162

Discrepancies?
==============

We need to make sure the split model is equivalent to the original one.
Some data first.

.. GENERATED FROM PYTHON SOURCE LINES 162-165

.. code-block:: default

    sess_full = InferenceSession(model_name,
                                 providers=["CPUExecutionProvider"])


.. GENERATED FROM PYTHON SOURCE LINES 166-167

The inputs.

.. GENERATED FROM PYTHON SOURCE LINES 167-175

.. code-block:: default

    for i in sess_full.get_inputs():
        print(f"input {i}, name={i.name!r}, type={i.type}, shape={i.shape}")
        input_name = i.name
        input_shape = list(i.shape)
        if input_shape[0] in [None, "batch_size", "N"]:
            input_shape[0] = 1


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    input NodeArg(name='input', type='tensor(float)', shape=['batch_size', 3, 224, 224]), name='input', type=tensor(float), shape=['batch_size', 3, 224, 224]

.. GENERATED FROM PYTHON SOURCE LINES 176-177

The outputs.

.. GENERATED FROM PYTHON SOURCE LINES 177-186

.. code-block:: default

    output_name = None
    for i in sess_full.get_outputs():
        print(f"output {i}, name={i.name!r}, type={i.type}, shape={i.shape}")
        if output_name is None:
            output_name = i.name

    print(f"input_name={input_name!r}, output_name={output_name!r}")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    output NodeArg(name='output', type='tensor(float)', shape=['batch_size', 1000]), name='output', type=tensor(float), shape=['batch_size', 1000]
    input_name='input', output_name='output'

.. GENERATED FROM PYTHON SOURCE LINES 187-188

The data.

.. GENERATED FROM PYTHON SOURCE LINES 188-199

.. code-block:: default

    if model_name == "gpt2.onnx":
        with open("encoded_tensors-gpt2.pkl", "rb") as f:
            [encoded_tensors, labels] = pickle.load(f)
        imgs = [x["input_ids"].numpy() for x in encoded_tensors[:maxN]]
    else:
        imgs = [numpy.random.rand(*input_shape).astype(numpy.float32)
                for i in range(maxN)]


.. GENERATED FROM PYTHON SOURCE LINES 200-201

The split model.

.. GENERATED FROM PYTHON SOURCE LINES 201-211

.. code-block:: default

    sess_split = []
    for name in piece_names:
        try:
            sess_split.append(InferenceSession(
                name, providers=["CPUExecutionProvider"]))
        except InvalidArgument as e:
            raise RuntimeError(f"Part {name!r} cannot be loaded.") from e

    input_names = [sess.get_inputs()[0].name for sess in sess_split]


.. GENERATED FROM PYTHON SOURCE LINES 212-213

We are ready to compute the outputs from both models.

.. GENERATED FROM PYTHON SOURCE LINES 213-224

.. code-block:: default

    expected = sess_full.run(None, {input_name: imgs[0]})[0]

    x = imgs[0]
    for sess, name in zip(sess_split, input_names):
        feeds = {name: x}
        x = sess.run(None, feeds)[0]

    diff = numpy.abs(expected - x).max()
    print(f"Max difference: {diff}")


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Max difference: 0.0

.. GENERATED FROM PYTHON SOURCE LINES 225-231

Everything works.
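
On CPU, the chained pieces match the original model bit for bit. Once
the pieces run on different GPUs, exact equality is no longer
guaranteed, so a tolerance-based check is safer. A minimal sketch, the
``atol`` value being an arbitrary choice (the benchmark below uses
``1e-3``):

.. code-block:: python

    # Tolerance-based comparison instead of strict equality.
    assert_allclose(expected, x, atol=1e-4)
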
Parallelization on GPU
======================

First, the implementation of a sequential run.

.. GENERATED FROM PYTHON SOURCE LINES 231-246

.. code-block:: default

    def sequence_ort_value(sesss, imgs):
        assert len(sesss) == 1
        sess = sesss[0]
        ort_device = get_ort_device_from_session(sess)
        res = []
        for img in imgs:
            ov = C_OrtValue.ortvalue_from_numpy(img, ort_device)
            out = sess._sess.run_with_ort_values(
                {input_name: ov}, [output_name], None)[0]
            res.append(out.numpy())
        return res, {}, True


.. GENERATED FROM PYTHON SOURCE LINES 247-248

And the parallel execution.

.. GENERATED FROM PYTHON SOURCE LINES 248-360

.. code-block:: default

    class MyThreadOrtValue(threading.Thread):

        def __init__(self, sess, batch_size, next_thread=None, wait_time=1e-4):
            threading.Thread.__init__(self)
            if batch_size <= 0:
                raise ValueError(f"batch_size={batch_size} must be positive.")
            self.sess = sess
            self.wait_time = wait_time
            self.ort_device = get_ort_device_from_session(self.sess)
            self.next_thread = next_thread
            self.input_name = self.sess.get_inputs()[0].name
            self.output_name = self.sess.get_outputs()[0].name
            self.batch_size = batch_size
            # for the execution
            self.inputs = []
            self.outputs = []
            self.waiting_time0 = 0
            self.waiting_time = 0
            self.run_time = 0
            self.copy_time_1 = 0
            self.copy_time_2 = 0
            self.twait_time = 0
            self.total_time = 0

        def append(self, pos, img):
            if not isinstance(img, numpy.ndarray):
                raise TypeError(f"numpy array expected not {type(img)}.")
            if pos >= self.batch_size or pos < 0:
                raise RuntimeError(
                    f"Cannot append an image, pos={pos} not in "
                    f"[0, {self.batch_size}[. The thread should be finished.")
            self.inputs.append((pos, img))

        def run(self):
            ort_device = self.ort_device
            sess = self.sess._sess
            processed = 0
            while processed < self.batch_size:
                # wait for an image
                tw = time.perf_counter()
                while processed >= len(self.inputs):
                    self.waiting_time += self.wait_time
                    if len(self.inputs) == 0:
                        self.waiting_time0 += self.wait_time
                    time.sleep(self.wait_time)
                pos, img = self.inputs[processed]
                t0 = time.perf_counter()
                ov = C_OrtValue.ortvalue_from_numpy(img, ort_device)
                t1 = time.perf_counter()
                out = sess.run_with_ort_values({self.input_name: ov},
                                               [self.output_name], None)[0]
                t2 = time.perf_counter()
                cpu_res = out.numpy()
                t3 = time.perf_counter()
                self.outputs.append((pos, cpu_res))

                # send the result to the next part
                if self.next_thread is not None:
                    self.next_thread.append(pos, cpu_res)
                self.inputs[processed] = None  # deletion
                processed += 1

                t4 = time.perf_counter()
                self.copy_time_1 += t1 - t0
                self.run_time += t2 - t1
                self.copy_time_2 += t3 - t2
                self.twait_time += t0 - tw
                self.total_time += t4 - tw


    def parallel_ort_value(sesss, imgs, wait_time=1e-4):
        threads = []
        for i in range(len(sesss)):
            sess = sesss[-i - 1]
            next_thread = threads[-1] if i > 0 else None
            th = MyThreadOrtValue(
                sess, len(imgs), next_thread, wait_time=wait_time)
            threads.append(th)
        threads = list(reversed(threads))
        for i, img in enumerate(imgs):
            threads[0].append(i, img)
        for t in threads:
            t.start()
        res = []
        th = threads[-1]
        th.join()
        res.extend(th.outputs)
        indices = [r[0] for r in res]
        order = list(sorted(indices)) == indices
        res.sort()
        res = [r[1] for r in res]
        times = {"wait": [], "wait0": [], "copy1": [], "copy2": [],
                 "run": [], "ttime": [], "wtime": []}
        for t in threads:
            times["wait"].append(t.waiting_time)
            times["wait0"].append(t.waiting_time0)
            times["copy1"].append(t.copy_time_1)
            times["copy2"].append(t.copy_time_2)
            times["run"].append(t.run_time)
            times["ttime"].append(t.total_time)
            times["wtime"].append(t.twait_time)
        return res, times, order
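
Since the pieces were loaded above on CPU, the pipeline can be
exercised without any GPU. This quick check is not part of the
original benchmark and assumes ``get_ort_device_from_session`` maps a
CPU session to the CPU device:

.. code-block:: python

    # Illustrative: run the pipeline on the CPU sessions created earlier
    # and compare with the single-session run.
    res_par, times, order = parallel_ort_value(sess_split, imgs[:4])
    res_seq, _, _ = sequence_ort_value([sess_full], imgs[:4])
    print("in order:", order)
    print("max diff:", max(numpy.abs(a - b).max()
                           for a, b in zip(res_par, res_seq)))
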
.. GENERATED FROM PYTHON SOURCE LINES 361-367

Functions
=========

The benchmark runs one function on all batch sizes, then deletes the
model before moving to the next function in order to free the GPU
memory.

.. GENERATED FROM PYTHON SOURCE LINES 367-578

.. code-block:: default

    def benchmark(fcts, model_name, piece_names, imgs, stepN=1, repN=4):
        data = []
        Ns = list(range(1, len(imgs), stepN))
        ns_name = {}
        results = {}
        for name, build_fct, fct in fcts:
            ns_name[name] = []
            results[name] = []
            sesss = build_fct()
            fct(sesss, imgs[:2])  # warm-up

            for N in tqdm.tqdm(Ns):
                all_times = []
                begin = time.perf_counter()
                for i in range(repN):
                    r, times, order = fct(sesss, imgs[:N])
                    all_times.append(times)
                end = (time.perf_counter() - begin) / repN
                times = {}
                for key in all_times[0].keys():
                    times[key] = sum(numpy.array(t[key])
                                     for t in all_times) / repN
                obs = {'n_imgs': len(imgs), 'maxN': maxN, 'stepN': stepN,
                       'repN': repN, 'batch_size': N,
                       'n_threads': len(sesss), 'name': name}
                obs.update({"n_imgs": len(r), "time": end})
                obs['order'] = order
                if len(times) > 0:
                    obs.update(
                        {f"wait0_{i}": t for i, t in enumerate(times["wait0"])})
                    obs.update(
                        {f"wait_{i}": t for i, t in enumerate(times["wait"])})
                    obs.update(
                        {f"copy1_{i}": t for i, t in enumerate(times["copy1"])})
                    obs.update(
                        {f"copy2_{i}": t for i, t in enumerate(times["copy2"])})
                    obs.update(
                        {f"run_{i}": t for i, t in enumerate(times["run"])})
                    obs.update(
                        {f"ttime_{i}": t for i, t in enumerate(times["ttime"])})
                    obs.update(
                        {f"wtime_{i}": t for i, t in enumerate(times["wtime"])})
                ns_name[name].append(len(r))
                results[name].append((r, obs))
                data.append(obs)

            del sesss
            gc.collect()

        # Let's make sure again that the outputs are the same when the
        # inference is parallelized.
        names = list(ns_name)
        baseline = ns_name[names[0]]
        for name in names[1:]:
            if ns_name[name] != baseline:
                raise RuntimeError(
                    f"Cannot compare experiments as they return different "
                    f"numbers of results, ns_name={ns_name}, obs={obs}.")
        baseline = results[names[0]]
        for name in names[1:]:
            if len(results[name]) != len(baseline):
                raise RuntimeError("Cannot compare.")
            for i1, ((b, o1), (r, o2)) in enumerate(
                    zip(baseline, results[name])):
                if len(b) != len(r):
                    raise RuntimeError(
                        f"Cannot compare: len(b)={len(b)} != len(r)={len(r)}.")
                for i2, (x, y) in enumerate(zip(b, r)):
                    try:
                        assert_allclose(x, y, atol=1e-3)
                    except AssertionError as e:
                        raise AssertionError(
                            f"Issue with baseline={names[0]!r} and {name!r}, "
                            f"i1={i1}/{len(baseline)}, i2={i2}/{len(b)}\n"
                            f"o1={o1}\no2={o2}") from e

        return pandas.DataFrame(data)
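
The plotting function below leans on :meth:`pandas.DataFrame.pivot`.
In miniature, with made-up numbers, this is what the reshaping does:

.. code-block:: python

    # Illustrative only: pivot turns (n_imgs, name, time) rows into one
    # column per function name, which is what make_plot plots.
    import pandas
    df_demo = pandas.DataFrame({
        "n_imgs": [1, 1, 2, 2],
        "name": ["sequence", "parallel", "sequence", "parallel"],
        "time": [1.0, 0.7, 2.0, 1.2]})
    print(df_demo.pivot(index="n_imgs", columns="name", values="time"))

``make_plot`` uses this reshaping to draw one curve per function name.
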
perf["wait"] = perf[cs].sum(axis=1) perf[["perf gain", "wait"] + cs].plot(ax=a) a.set_title("gain / batch_size\n((baseline - parallel) / baseline", fontsize="x-small") a.legend(fontsize="x-small") a.set_ylim([0, None]) a.set_ylabel("%", fontsize="x-small") a.set_xlabel("batch size", fontsize="x-small") # wait a = ax[1, 0] wait0 = df[["n_imgs"] + wcol0].set_index("n_imgs") wait0.plot(ax=a) a.set_title("Time waiting for the first image per thread", fontsize="x-small") a.legend(fontsize="x-small") a.set_ylabel("seconds", fontsize="x-small") a = ax[1, 1] wait = df[["n_imgs"] + wcol].set_index("n_imgs") wait.plot(ax=a) a.set_title("Total time waiting per thread", fontsize="x-small") a.legend(fontsize="x-small") a.set_ylabel("seconds", fontsize="x-small") a = ax[1, 2] wait = df[["n_imgs"] + wcol] div = wait["n_imgs"] wait = wait.set_index("n_imgs") for c in wait.columns: wait[c] /= div.values wait.plot(ax=a) a.set_title( "Total time waiting per thread\ndivided by batch size", fontsize="x-small") a.legend() a.set_ylim([0, None]) a.set_ylabel("seconds", fontsize="x-small") a.set_xlabel("batch size", fontsize="x-small") a.set_xlabel("batch size", fontsize="x-small") # ttime a = ax[0, 3] ttimes = [c for c in df.columns if c.startswith('ttime_')] n_threads = len(ttimes) sub = df.loc[~df.run_0.isnull(), ["n_imgs", "time"] + ttimes].copy() for c in sub.columns[1:]: sub[c] /= sub["n_imgs"] sub.set_index("n_imgs").plot(ax=a, logy=True) a.set_title("Total time (parallel)\ndivided by batch size", fontsize="x-small") a.set_ylabel("seconds", fontsize="x-small") a.set_xlabel("batch size", fontsize="x-small") a = ax[1, 3] run = [c for c in df.columns if c.startswith('run_')] sub = df.loc[~df.run_0.isnull(), ["n_imgs", "time"] + run].copy() for c in sub.columns[2:]: sub[c] /= sub["time"] sub["time"] = 1 sub.set_index("n_imgs").plot(ax=a) a.set_title( "Ratio running time per thread / total time\ndivided by batch size", fontsize="x-small") a.set_ylabel("seconds", fontsize="x-small") a.set_xlabel("batch size", fontsize="x-small") a.legend(fontsize="x-small") # other time pos = [2, 0] cols = ['wtime', 'copy1', 'run', 'copy2', 'ttime'] for nth in range(n_threads): a = ax[tuple(pos)] cs = [f"{c}_{nth}" for c in cols] sub = df.loc[~df.run_0.isnull(), ["n_imgs", "time"] + cs].copy() for c in cs: sub[c] /= sub[f"ttime_{nth}"] sub.set_index("n_imgs")[cs[:-1]].plot.area(ax=a) a.set_title(f"Part {nth + 1}/{n_threads}", fontsize="x-small") a.set_xlabel("batch size", fontsize="x-small") a.legend(fontsize="x-small") pos[1] += 1 if pos[1] >= ax.shape[1]: pos[1] = 0 pos[0] += 1 fig.suptitle(f"{title} - {n_parts} splits") fig.savefig(f"img-{n_parts}-splits-{title.replace(' ', '_')}.png", dpi=200) return ax .. GENERATED FROM PYTHON SOURCE LINES 579-582 Benchmark ========= .. GENERATED FROM PYTHON SOURCE LINES 582-617 .. 
.. code-block:: default

    def build_sequence():
        return [InferenceSession(model_name,
                                 providers=["CUDAExecutionProvider",
                                            "CPUExecutionProvider"])]


    def build_parallel_pieces():
        sesss = []
        for i in range(len(piece_names)):
            print(f"Initialize device {i} with {piece_names[i]!r}")
            sesss.append(
                InferenceSession(
                    piece_names[i],
                    providers=["CUDAExecutionProvider",
                               "CPUExecutionProvider"],
                    provider_options=[{"device_id": i}, {}]))
        return sesss


    if n_gpus > 1:
        print("ORT // GPUs")
        df = benchmark(model_name=model_name, piece_names=piece_names,
                       imgs=imgs, stepN=stepN, repN=repN,
                       fcts=[('sequence', build_sequence,
                              sequence_ort_value),
                             ('parallel', build_parallel_pieces,
                              parallel_ort_value)])
        df.reset_index(drop=False).to_csv("ort_gpus_piece.csv", index=False)
        title = os.path.splitext(model_name)[0]
    else:
        print("No GPU is available but data should be like the following.")
        df = pandas.read_csv("data/ort_gpus_piece.csv")
        title = "Saved mobilenet"

    df


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    No GPU is available but data should be like the following.
.. code-block:: none

            index  n_imgs  maxN  stepN  repN  batch_size  n_threads  name  time  order  wait0_0  wait0_1  wait0_2  wait0_3  wait_0  wait_1  wait_2  wait_3  copy1_0  copy1_1  copy1_2  copy1_3  copy2_0  copy2_1  copy2_2  copy2_3  run_0  run_1  run_2  run_3  ttime_0  ttime_1  ttime_2  ttime_3  wtime_0  wtime_1  wtime_2  wtime_3
    0   0  1  81  2  4  1  1  sequence  0.001724  True  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN
    1   1  3  81  2  4  3  1  sequence  0.005125  True  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN
    2   2  5  81  2  4  5  1  sequence  0.008465  True  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN
    3   3  7  81  2  4  7  1  sequence  0.011681  True  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN
    4   4  9  81  2  4  9  1  sequence  0.015074  True  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN
    ...
    75  75  71  81  2  4  71  4  parallel  0.105721  True  0.0  0.000725  0.001075  0.001275  0.0  0.023900  0.032650  0.037625  0.012286  0.004116  0.004901  0.005494  0.003491  0.005229  0.006400  0.000827  0.087326  0.031705  0.017976  0.022040  0.103494  0.103850  0.104286  0.104414  0.000059  0.062473  0.074706  0.075935
    76  76  73  81  2  4  73  4  parallel  0.110922  True  0.0  0.000725  0.001125  0.001325  0.0  0.022225  0.035425  0.038775  0.012716  0.004558  0.004842  0.005608  0.003529  0.005694  0.007135  0.000920  0.091783  0.036628  0.019098  0.021592  0.108431  0.109037  0.109470  0.109582  0.000063  0.061692  0.078089  0.081336
    77  77  75  81  2  4  75  4  parallel  0.114708  True  0.0  0.000725  0.001050  0.001200  0.0  0.025325  0.037075  0.040750  0.012794  0.004679  0.005102  0.005848  0.003819  0.005809  0.007230  0.000883  0.095368  0.034001  0.018379  0.024682  0.112416  0.112837  0.113262  0.113380  0.000064  0.067942  0.082241  0.081836
    78  78  77  81  2  4  77  4  parallel  0.115314  True  0.0  0.000750  0.001125  0.001300  0.0  0.025375  0.036250  0.040950  0.013208  0.004503  0.005343  0.005850  0.003750  0.005939  0.007678  0.000808  0.095631  0.034800  0.019410  0.025095  0.113017  0.113407  0.113855  0.113980  0.000064  0.067827  0.081103  0.082087
    79  79  79  81  2  4  79  4  parallel  0.118628  True  0.0  0.000725  0.001100  0.001275  0.0  0.025800  0.038175  0.042500  0.013690  0.004642  0.005369  0.006016  0.003820  0.005923  0.007324  0.000858  0.098371  0.036499  0.019643  0.027283  0.116312  0.116742  0.117168  0.117285  0.000066  0.069318  0.084512  0.082985

    [80 rows x 38 columns]



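A quick way to summarize the recorded gain, not in the original
script, reusing the table loaded above into ``df``:

.. code-block:: python

    # Relative gain of the parallel run over the sequential one.
    piv = df.pivot(index="n_imgs", columns="name", values="time")
    gain = (piv["sequence"] - piv["parallel"]) / piv["sequence"]
    print(gain.describe())
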
.. GENERATED FROM PYTHON SOURCE LINES 618-619

Plots.

.. GENERATED FROM PYTHON SOURCE LINES 619-623

.. code-block:: default

    ax = make_plot(df, title)
    ax


.. image-sg:: /gyexamples/images/sphx_glr_plot_parallel_execution_big_model_001.png
   :alt: Saved mobilenet - 4 splits, time(s), time(s) / batch_size, gain / batch_size (baseline - parallel) / baseline, Total time (parallel) divided by batch size, Time waiting for the first image per thread, Total time waiting per thread, Total time waiting per thread divided by batch size, Ratio running time per thread / total time divided by batch size, Part 1/4, Part 2/4, Part 3/4, Part 4/4
   :srcset: /gyexamples/images/sphx_glr_plot_parallel_execution_big_model_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    array([[<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
           [<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
           [<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>]],
          dtype=object)

.. GENERATED FROM PYTHON SOURCE LINES 624-629

Recorded results
================

The parallelization on multiple GPUs did work. Here are the results
recorded with model resnet18.

.. GENERATED FROM PYTHON SOURCE LINES 629-636

.. code-block:: default

    data = pandas.read_csv("data/ort_gpus_piece_resnet18.csv")
    df = pandas.DataFrame(data)
    ax = make_plot(df, "Saved resnet 18")
    ax


.. image-sg:: /gyexamples/images/sphx_glr_plot_parallel_execution_big_model_002.png
   :alt: Saved resnet 18 - 2 splits, time(s), time(s) / batch_size, gain / batch_size (baseline - parallel) / baseline, Total time (parallel) divided by batch size, Time waiting for the first image per thread, Total time waiting per thread, Total time waiting per thread divided by batch size, Ratio running time per thread / total time divided by batch size, Part 1/2, Part 2/2
   :srcset: /gyexamples/images/sphx_glr_plot_parallel_execution_big_model_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    array([[<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
           [<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
           [<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>]],
          dtype=object)

.. GENERATED FROM PYTHON SOURCE LINES 637-638

With GPT2.

.. GENERATED FROM PYTHON SOURCE LINES 638-649

.. code-block:: default

    if os.path.exists("data/ort_gpus_piece_gpt2.csv"):
        data = pandas.read_csv("data/ort_gpus_piece_gpt2.csv")
        df = pandas.DataFrame(data)
        ax = make_plot(df, "Saved GPT2")
    else:
        print("No result yet.")
        ax = None
    ax

    # plt.show()


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    No result yet.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 17.796 seconds)


.. _sphx_glr_download_gyexamples_plot_parallel_execution_big_model.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_parallel_execution_big_model.py <plot_parallel_execution_big_model.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_parallel_execution_big_model.ipynb <plot_parallel_execution_big_model.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_