2021-08-12 A few tricks for tf2onnx

A few things I tend to forget. To run a specific test with a specific opset:

python tests/test_backend.py --opset 12 BackendTests.test_rfft2d_ops_specific_dimension
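The same --opset switch exists on the converter command line as well (the saved-model path below is just a placeholder):

python -m tf2onnx.convert --saved-model path/to/saved_model --opset 12 --output model.onnx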

Optimisation of an ONNX file: the following script applies the whole list of optimizers available in tensorflow-onnx.

import onnx
from onnx import helper
from tf2onnx.graph import GraphUtil
# tf2onnx ships its own logging module, a thin wrapper around the standard one
from tf2onnx import logging, optimizer, constants
from tf2onnx.late_rewriters import rewrite_channels_first, rewrite_channels_last

logging.basicConfig(level=logging.DEBUG)

def load_graph(fname, target):
    """Loads an ONNX file and wraps it into a tf2onnx Graph."""
    model_proto = onnx.ModelProto()
    with open(fname, "rb") as f:
        data = f.read()
        model_proto.ParseFromString(data)
    g = GraphUtil.create_graph_from_onnx_model(model_proto, target)
    return g, model_proto

def optimize(input_path, output_path):
    """Applies the whole list of tf2onnx optimizers to an ONNX file."""
    g, org_model_proto = load_graph(input_path, [])  # no specific target
    # optional layout rewriters run first
    if g.is_target(constants.TARGET_CHANNELS_FIRST):
        g.reset_nodes(rewrite_channels_first(g, g.get_nodes()))
    if g.is_target(constants.TARGET_CHANNELS_LAST):
        g.reset_nodes(rewrite_channels_last(g, g.get_nodes()))
    # then every registered optimizer (transpose, constant folding, ...)
    g = optimizer.optimize_graph(g)
    onnx_graph = g.make_graph(
        org_model_proto.graph.doc_string + " (+tf2onnx/onnx-optimize)")
    # reuse the original model's metadata (opset imports, producer, ...)
    kwargs = GraphUtil.get_onnx_model_properties(org_model_proto)
    model_proto = helper.make_model(onnx_graph, **kwargs)
    with open(output_path, "wb") as f:
        f.write(model_proto.SerializeToString())

optimize("debug_noopt.onnx", "debug_opt.onnx")