MRT Graph API

mrt.tfm_pass

Collection of MRT pass functions. Graph-level symbol helper functions for MRT. Stage-level symbol pass designation for MRT.

mrt.tfm_pass.calculate_ops(symbol, params, normalize=True)

Customized graph-level topo pass definition.

Calculate the number of operations with respect to the given input shape.

mrt.tfm_pass.fuse_transpose(symbol, params)

Customized graph-level topo pass definition.

Equivalent graph transformation. Fuse or swap the current operator with the former Transpose.
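One case such a pass can cover is two consecutive Transpose operators, which collapse into a single one by composing their axes. The sketch below works on plain shape tuples rather than MxNet symbols; the helper names are hypothetical and it is not the MRT implementation.

```python
def compose_transpose_axes(first_axes, second_axes):
    """Axes of one Transpose equal to applying first_axes, then second_axes."""
    # Output axis i of the second transpose reads axis second_axes[i] of the
    # intermediate tensor, which in turn is axis first_axes[...] of the input.
    return [first_axes[i] for i in second_axes]


def transposed_shape(shape, axes):
    """Shape produced by transposing a tensor of the given shape."""
    return tuple(shape[i] for i in axes)


shape = (2, 3, 5, 7)
first, second = [0, 2, 3, 1], [0, 3, 1, 2]

fused = compose_transpose_axes(first, second)
two_step = transposed_shape(transposed_shape(shape, first), second)
one_step = transposed_shape(shape, fused)
```

Here the two permutations cancel, so the fused axes are the identity and the pair could be dropped from the graph altogether.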

mrt.tfm_pass.rewrite(symbol, params)

Customized graph-level topo pass definition.

Equivalent graph transformation. Rewrite the current symbol into equivalent ones that are feasible for quantization.

mrt.tfm_pass.quantize(symbol, params, th_dict, precs, scales, op_input_precs, restore_names, shift_bits, softmax_lambd)

Customized graph-level topo pass definition.

MRT quantization process function.

The original graph can be denoted as follows:

Case 1. Enable Restore

X1   X2  ...  Xn
xs1  xs2 ...  xsn
\    |       /
 \   |      /
  \  |     /
  current_op

where ‘xs1’, ‘xs2’, etc. stand for the scales of the respective inputs.

The original graph will be restored into:

X1   X2  ... Xn
xs1  xs2 ... xsn
|    |       |
|    |       |
RX1  RX2 ... RXn
1    1       1
\    |       /
 \   |      /
  \  |     /
  current_op
     1
     |
     |
    ...

where ‘RX1’, ‘RX2’, … stand for the restored symbols with unit scale.

Case 2. Other Cases

The original graph will be quantized into:

        X1   X2  ... Xn
        xs1  xs2 ... xsn
        |    |       |
        |    |       |
        QX1  QX2 ... QXn
        \    |       /
         \   |      /
          \  |     /
[equivalent quantization subgraph]
             |
             |
            ...

In either case, if the ‘current_op’ is one of the output symbols, it will be quantized again with respect to the specified output precision.
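The restore case can be pictured with plain numbers. The sketch below assumes that a scaled value relates to its real value by value_scaled = value_real * scale, and that output quantization is symmetric with scale = (2^(precision-1) - 1) / threshold. Both conventions, and the function names, are assumptions made for illustration and are not taken from the MRT source.

```python
def restore(values, scale):
    """Map scaled values back to unit scale (assumes real = scaled / scale)."""
    return [v / scale for v in values]


def requantize(values, threshold, precision):
    """Symmetric quantization into a signed range of the given bit precision."""
    max_int = 2 ** (precision - 1) - 1
    scale = max_int / threshold          # assumed scale convention
    quantized = [max(-max_int, min(max_int, round(v * scale))) for v in values]
    return quantized, scale


# Hypothetical input X1 with integer payload x_q and scale xs.
x_q, xs = [64, -127, 32], 127.0 / 2.0

rx = restore(x_q, xs)                    # RX1: unit-scale values for current_op
out_q, out_scale = requantize(rx, threshold=2.0, precision=8)
```

With these numbers the round trip reproduces the original integers, which is the sense in which the restored graph is equivalent.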

mrt.tfm_pass.prepare_for_compile(symbol, params)

Customized graph-level topo pass definition.

Equivalent graph transformation. Convert quantized symbols into feasible ones for compile.

mrt.tfm_pass.to_cvm(symbol, params)

Customized graph-level topo pass definition.

Cast attributes to CVM symbol attributes. MxNet symbols are cast to CVM symbols using the cvm-runtime libraries.

mrt.tfm_pass.fuse_multiple_inputs(sym, params)

Customized graph-level topo pass definition.

The original graph has multiple inputs.

Inp1 Inp2  ...  Inpn
 |   |           |
 |   |           |
... ...    ...  ...

where ‘Inp1’, ‘Inp2’ and so on stand for the input symbols.

The original graph will be transformed into:

               data_sum
             /     |      \
            /      |       \
           /       |        \
          /        |         \
         /         |          \
        /          |           \
slice_axis     slice_axis ...   slice_axis
axis: 0        axis: 0          axis: 0
begin:0        begin:end1       begin: end(n-1)
end:pd1        end:end1+pd2     end:end(n-1)+pdn
   /               |               \
   |               |                |
reshape        reshape          reshape
shape:shp1     shape:shp2       shape:shpn
   |               |                |
   |               |                |
  Inp1            Inp2    ...      Inpn
   |               |                |
   |               |                |
  ...             ...     ...      ...

where ‘data_sum’ is the tiled and concatenated form of all the inputs, ‘pd1’, ‘pd2’, …, ‘pdn’ are the products of the inferred shapes of the inputs, and ‘shp1’, ‘shp2’, …, ‘shpn’ are the inferred shapes of the inputs.
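The slice offsets in the diagram follow directly from the input shapes. As a minimal sketch on plain Python lists (no MxNet symbols involved), the begin/end pairs and the reshape step could be computed as follows; `slice_plan` and `reshape` are hypothetical helpers, and the buffer is assumed to hold the flattened inputs back to back.

```python
from math import prod


def slice_plan(shapes):
    """Return (begin, end, shape) for each input inside the flat data_sum."""
    plan, begin = [], 0
    for shape in shapes:
        size = prod(shape)               # pd_i: element count of the i-th input
        plan.append((begin, begin + size, shape))
        begin += size                    # the next slice starts where this ends
    return plan


def reshape(flat, shape):
    """Reshape a flat list into nested lists of the given shape (row-major)."""
    if len(shape) == 1:
        return list(flat)
    step = len(flat) // shape[0]
    return [reshape(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


shapes = [(2, 3), (4,), (1, 2, 2)]       # hypothetical shp1..shpn
data_sum = list(range(14))               # 6 + 4 + 4 elements in one buffer

plan = slice_plan(shapes)
inputs = [reshape(data_sum[b:e], shp) for b, e, shp in plan]
```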

mrt.tfm_pass.name_duplicate_check(symbol, params)

Check whether duplicate names exist in the graph.
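For illustration, a duplicate check over a plain list of node names can be written with a counter. This is a sketch only; how MRT collects the names from the graph is not shown here, and `duplicate_names` is a hypothetical helper.

```python
from collections import Counter


def duplicate_names(names):
    """Return the names that occur more than once, in sorted order."""
    counts = Counter(names)
    return sorted(name for name, count in counts.items() if count > 1)


names = ["data", "conv0", "relu0", "conv0"]   # hypothetical node names
dups = duplicate_names(names)
```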

mrt.tfm_pass.params_unique(symbol, params)

Remove duplicate keys from the params dict.

mrt.tfm_pass.input_name_replace(symbol, params)

Customized graph-level topo pass definition.

Replace the name of the single input with ‘data’.

mrt.tfm_pass.fuse_constant(symbol, params)

Customized graph-level topo pass definition.

Fix the constant operator shape with respect to the infer shape.

mrt.tfm_pass.attach_input_shape(symbol, params, input_shapes)

Customized graph-level topo pass definition.

Attach the infer shapes for the graph.

Parameters
  • symbol (mxnet.symbol) – The graph symbols.

  • params (dict) – The graph parameters dict.

  • input_shapes (dict) – The name-shape map.

Returns

ret – The symbol-param tuple after input shapes attachment.

Return type

tuple

mrt.tfm_pass.infer_shape(symbol, params, input_shape=None)

Customized graph-level topo pass definition.

Collect the infer shapes from graph.

Parameters
  • symbol (mxnet.symbol) – The graph symbols.

  • params (dict) – The graph parameters dict.

  • input_shape (tuple) – The input shape of the data.

Returns

ret – The name-shape map.

Return type

dict

mrt.tfm_pass._collect_attribute(op, **kwargs)

Collect the attribute value from the specified operator.

Parameters

op (mxnet.symbol) – The input operator.

Returns

ret – The output operator.

Return type

mxnet.symbol

mrt.tfm_pass.collect_op_names(symbol, params)

Customized graph-level topo pass definition.

Collect all kinds of operators that exist in the graph.

Parameters
  • symbol (mxnet.symbol) – The graph symbols.

  • params (dict) – The graph parameters dict.

Returns

ret – The opname set.

Return type

set

mrt.tfm_pass.fuse_multiple_outputs(symbol, params)

Customized symbol-level topo pass definition.

Symbol-level multiple-outputs-fusion pass.

          X
          |
          |
     SliceChannel
     /    |     \
    /     |      \
   /      |       \
  /       |        \
A1        A2   ...  An

The original graph will be transformed into:

                   X
             /     |      \
            /      |       \
           /       |        \
          /        |         \
         /         |          \
        /          |           \
slice_axis     slice_axis ...   slice_axis
axis: ich      axis: ich        axis: ich
begin:0        begin:stride     begin: (n-1)*stride
end:stride     end:2*stride     end:n*stride
   /               |               \
  A1               A2      ...      An

where ‘ich’ is the attribute ‘axis’ along which to split, and

\[stride = dim / num\_outputs\]

where ‘dim’ is the size of ‘xshape’ along ‘ich’, ‘xshape’ is the inferred shape of ‘X’, and ‘num_outputs’ is the number of splits.
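The begin/end attributes of each slice_axis follow from the stride. The sketch below computes them for a given shape; it assumes that ‘dim’ is the size of the input along the split axis and that it divides evenly by the number of outputs. The function name is hypothetical.

```python
def slice_channel_to_slice_axis(xshape, axis, num_outputs):
    """Return the slice_axis attributes that replace one SliceChannel."""
    dim = xshape[axis]                    # assumed meaning of 'dim'
    if dim % num_outputs != 0:
        raise ValueError("axis size must be divisible by num_outputs")
    stride = dim // num_outputs
    return [
        {"axis": axis, "begin": i * stride, "end": (i + 1) * stride}
        for i in range(num_outputs)
    ]


# Hypothetical input of shape (1, 12, 8, 8) split into 3 parts along axis 1.
attrs = slice_channel_to_slice_axis((1, 12, 8, 8), axis=1, num_outputs=3)
```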

mrt.tfm_pass._get_opt(out, lambd)

Compute the calibration threshold ‘opt’ for the given output.

\[mean\_v = mean(out)\]
\[n = product(shape(out))\]
\[sqrt\_n = sqrt(n)\]
\[std = norm(out - mean\_v) / sqrt\_n\]
\[alpha = |mean\_v| + lambd * std\]
\[absmax = ||out||_\infty\]

Case 1. alpha < 0.95 * absmax

opt = alpha

Case 2. Other Cases

opt = absmax

Parameters
  • out (nd.NDArray) – The graph level output.

  • lambd (float) – Hyperparameter that scales the ‘std’ term of alpha.

Returns

opt – The opt value.

Return type

float
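The rule above can be traced on a flat list of numbers. This sketch uses only the standard library instead of NDArray operations, and it takes ‘absmax’ to be the largest absolute entry of ‘out’, as the name suggests; treat that reading as an assumption. `get_opt` is a hypothetical stand-in, not the MRT function itself.

```python
import math


def get_opt(out, lambd):
    """Pick alpha when it is clearly below absmax, otherwise absmax."""
    n = len(out)
    mean_v = sum(out) / n
    # norm(out - mean_v) / sqrt(n), i.e. the population standard deviation
    std = math.sqrt(sum((v - mean_v) ** 2 for v in out)) / math.sqrt(n)
    alpha = abs(mean_v) + lambd * std
    absmax = max(abs(v) for v in out)     # assumed meaning of 'absmax'
    return alpha if alpha < 0.95 * absmax else absmax


outlier = [0.0] * 99 + [100.0]            # one extreme value: Case 1
uniform = [1.0, -1.0, 1.0, -1.0]          # no outliers: Case 2

opt_outlier = get_opt(outlier, lambd=3.0)
opt_uniform = get_opt(uniform, lambd=3.0)
```

With a single outlier the threshold drops well below the maximum, while for evenly spread data the maximum itself is kept.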

mrt.tfm_pass.sym_calibrate(symbol, params, data, **kwargs)

Customized graph-level topo pass definition.

Calibrate the MRT model after setting mrt data.

Parameters
  • symbol (mxnet.symbol) – The graph symbols.

  • params (dict) – The graph parameters dict.

  • data (nd.NDArray) – The input data.

Returns

ret – The threshold dict after calibration.

Return type

dict

mrt.tfm_pass.convert_params_dtype(params, src_dtypes=['float32', 'float64'], dest_dtype='float64')

Convert params of the source data type(s) to the target data type.

Parameters
  • params (nd.NDArray) – The input data to be converted.

  • src_dtypes (str or list of str) – The source data type(s).

  • dest_dtype (str) – The target data type.

Returns

ret – The precision of the input.

Return type

int

mrt.sym_utils

Collection of MRT pass tool functions. Topo sequence helper function for MxNet graph parsing. Simplification of MRT graph traversal.

Only the key utility functions are documented here.

mrt.sym_utils.is_op(sym, params)

Judge whether the input symbol is an operator.

Parameters
  • sym (mxnet.symbol) – The input symbol.

  • params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.

Returns

ret – Return False if the input symbol is not an operator, else True.

Return type

bool

mrt.sym_utils.is_var(sym, params)

Judge whether the input symbol is a variable.

Parameters
  • sym (mxnet.symbol) – The input symbol.

  • params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.

Returns

ret – Return False if the input symbol is not a variable, else True.

Return type

bool

mrt.sym_utils.is_params(sym, params)

Judge whether the input symbol is a parameter.

Parameters
  • sym (mxnet.symbol) – The input symbol.

  • params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.

Returns

ret – Return False if the input symbol is not a parameter, else True.

Return type

bool

mrt.sym_utils.is_inputs(sym, params)

Judge whether the input symbol is an input.

Parameters
  • sym (mxnet.symbol) – The input symbol.

  • params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.

Returns

ret – Return False if the input symbol is not an input, else True.

Return type

bool

mrt.sym_utils.nd_array(source_array, ctx=None, dtype='float64')

Convert the source array into MxNet NDArray.

Parameters
  • source_array (tuple or list) – The input array to be converted.

  • ctx (mxnet.context) – The context on which to store the converted array.

  • dtype (str) – The entry data type of the converted array.

Returns

ret – The converted result.

Return type

mxnet.NDArray

mrt.sym_utils.nd_arange(*args, **kwargs)

MRT wrapper method for mxnet.NDArray.arange.

mrt.sym_utils.nd_full(*args, **kwargs)

MRT wrapper method for mxnet.NDArray.full.

mrt.sym_utils.nd_zeros(*args, **kwargs)

MRT wrapper method for mxnet.NDArray.zeros.

mrt.sym_utils.nd_ones(*args, **kwargs)

MRT wrapper method for mxnet.NDArray.ones.

mrt.sym_utils.check_graph(symbol, params, logger=logging)

Check whether duplicate symbol name exists in a graph.

Also, check input name and params name, and remove unused params name.

Parameters
  • symbol (mxnet.symbol) – The input symbol.

  • params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.

Returns

ret – The validated symbol and params.

Return type

tuple

mrt.sym_utils.get_attr(attr, name, default=<object object>)

Get the value of the specified attribute.

Parameters
  • attr (str) – The input attribute name.

  • name (str) – The input symbol name.

Returns

ret – The attribute value.

Return type

str, int, float, etc.
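The `default=<object object>` in the signature suggests a sentinel default, which lets a caller pass `None` as a legitimate fallback. The sketch below illustrates that pattern under the assumption that the attributes are held in a dict-like mapping; `_NO_DEFAULT` and `get_attr_sketch` are hypothetical names and the real function may behave differently.

```python
_NO_DEFAULT = object()   # hypothetical sentinel meaning "no default supplied"


def get_attr_sketch(attr, name, default=_NO_DEFAULT):
    """Look up name in the attribute mapping, falling back to default."""
    if name in attr:
        return attr[name]
    if default is _NO_DEFAULT:
        # No fallback was given, so a missing attribute is an error.
        raise KeyError("attribute %r not found" % name)
    return default


attrs = {"axis": 1}                                      # hypothetical attributes
axis = get_attr_sketch(attrs, "axis")
fallback = get_attr_sketch(attrs, "num_outputs", None)   # None is a valid default
```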

mrt.sym_utils.get_nd_op(op_name)

Get the MxNet NDArray operator handle of the specified op type.

Typical usage: Calibration, op-level output restore, etc.

Parameters

op_name (str) – The input operator type name.

Returns

ret – The MxNet NDArray operator handle.

Return type

mxnet.NDArray

mrt.sym_utils.get_mxnet_op(op_name)

Get the MxNet symbol operator handle of the specified op type.

Parameters

op_name (str) – The input operator type name.

Returns

ret – The MxNet symbol operator handle.

Return type

mxnet.NDArray

mrt.sym_utils.get_nnvm_op(op_name)

Get the CVM symbol operator handle of the specified op type.

Parameters

op_name (str) – The input operator type name.

Returns

ret – The CVM symbol operator handle.

Return type

cvm.symbol

mrt.sym_utils.sym_iter(sym)

Iterate over the symbol and get a list of symbols from the symbol group.

Parameters

sym (mxnet.symbol or cvm.symbol) – The input symbol.

Returns

ret – list of CVM symbols or MxNet symbols.

Return type

list

mrt.sym_utils.nd_const(number, graph, params)

Get the MxNet symbol with respect to the given constant.

Parameters
  • number (float) – The input constant number to be converted.

  • graph (dict) – The symbol name maps to MxNet symbol.

  • params (dict) – The symbol name maps to MxNet NDArray.

Returns

ret – The created MxNet variable symbol.

Return type

mxnet.symbol

mrt.sym_utils.topo_sort(symbol, logger=logging, with_deps=False)

Sort all symbols in the mxnet graph in topological order.

Parameters
  • symbol (mxnet.symbol or cvm.symbol) – The input symbol.

  • with_deps (bool) – Whether to return op-level output dict or not, which maps symbol name to a set of output names.

Returns

ret – The sorted MxNet or CVM symbols; if with_deps is True, also the operator output dict.

Return type

tuple
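As a generic illustration of topological ordering, and not the MRT implementation, the sketch below runs a depth-first sort over a plain dict that maps each node name to the names of its inputs. It also builds the kind of name-to-consumers map that with_deps describes. All names here are hypothetical.

```python
def topo_sort_sketch(outputs, inputs_of):
    """Return node names so that every node appears after all of its inputs."""
    order, visited = [], set()

    def visit(name):
        if name in visited:
            return
        visited.add(name)
        for child in inputs_of.get(name, []):
            visit(child)
        order.append(name)                # emitted only after all inputs

    for name in outputs:
        visit(name)
    return order


def output_deps(order, inputs_of):
    """Map each node name to the set of node names that consume it."""
    deps = {name: set() for name in order}
    for name in order:
        for child in inputs_of.get(name, []):
            deps[child].add(name)
    return deps


graph = {"conv": ["data", "weight"], "relu": ["conv"], "add": ["relu", "data"]}
order = topo_sort_sketch(["add"], graph)
deps = output_deps(order, graph)
```

The recursive form is fine for small graphs; a very deep graph would need an explicit stack to stay within Python's recursion limit.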

mrt.sym_utils.get_entry_id(sym)

Get the entry id of the symbol with respect to its symbol group.

Parameters

sym (mxnet.symbol or cvm.symbol) – The input symbol.

Returns

ret – The entry id.

Return type

int

mrt.sym_utils.get_node(sym, graph)

Get the symbol from the provided graph which has the same name as the given symbol.

Assumes that every graph node has a single output.

Nodes with multiple outputs are fused by the fuse_multiple_outputs pass.

Parameters
  • sym (mxnet.symbol or cvm.symbol) – The input symbol.

  • graph (dict) – The symbol name maps to MxNet symbol or CVM symbol.

Returns

ret – The Mxnet symbol or CVM symbol.

Return type

mxnet.symbol or cvm.symbol

mrt.sym_utils.topo_visit(symbol, params, inputs_ext, callback, get_op=get_mxnet_op, logger=logging, with_maps=False, **kwargs)

MRT topological graph traversal function.

Parameters
  • symbol (mxnet.symbol or cvm.symbol) – The input symbol.

  • params (dict) – MxNet or CVM symbol name (str) maps to mxnet.NDArray or cvm.ndarray.

  • inputs_ext (dict) – Input info name maps to input info value.

  • callback (function) – op-level pass function.

  • get_op (function) – Operator acquisition function

  • with_maps (bool) – Whether to return the old-new name dict. The name of a symbol may change after it is visited.

Returns

ret – The MxNet or CVM symbol and the parameters dict; if with_maps is True, also the old-new name dict.

Return type

tuple
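The traversal pattern can be sketched on plain dicts: walk the nodes in topological order, hand each one to the callback together with its already-visited inputs, and record how names change. This is a simplified illustration with hypothetical names; the real function works on symbols and also threads params and inputs_ext through the callback.

```python
def topo_visit_sketch(order, inputs_of, callback):
    """Apply callback to each node in topological order and track renames."""
    new_inputs, name_map = {}, {}
    for name in order:                    # order must already be topological
        # Inputs were visited earlier, so their new names are known.
        children = [name_map[c] for c in inputs_of.get(name, [])]
        new_name = callback(name, children)
        name_map[name] = new_name
        new_inputs[new_name] = children
    return new_inputs, name_map


def add_prefix(name, children):
    """Hypothetical op-level pass: rename operators, leave variables alone."""
    return "mrt_" + name if children else name


order = ["data", "conv", "relu"]
graph = {"conv": ["data"], "relu": ["conv"]}
new_graph, name_map = topo_visit_sketch(order, graph, add_prefix)
```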

mrt.sym_utils.topo_visit_transformer(symbol, params, callback, get_op=get_mxnet_op, logger=logging, **kwargs)

MRT transformer topological graph traversal function.

Parameters
  • symbol (mxnet.symbol or cvm.symbol) – The input symbol.

  • params (dict) – MxNet or CVM symbol name (str) maps to mxnet.NDArray or cvm.ndarray.

  • callback (function) – op-level pass function.

  • get_op (function) – Operator acquisition function

Returns

ret – The Mxnet symbol or CVM symbol, parameters dict.

Return type

tuple