MRT Graph API
mrt.tfm_pass
Collection of MRT pass functions: graph-level symbol helper functions and stage-level symbol pass definitions for MRT.
- mrt.tfm_pass.calculate_ops(symbol, params, normalize=True)
Customized graph-level topo pass definition.
Calculate the number of operations with respect to the given input shape.
- mrt.tfm_pass.fuse_transpose(symbol, params)
Customized graph-level topo pass definition.
Equivalent graph transformation. Fuse the current operator with the preceding Transpose, or swap the two.
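As an illustration of the Transpose-after-Transpose case only, two consecutive transposes can be fused into one by composing their axes permutations. The sketch below uses NumPy arrays as a stand-in for MXNet symbols; the helper `fuse_two_transposes` is hypothetical and not part of MRT.

```python
import numpy as np


def fuse_two_transposes(first_axes, second_axes):
    """Return the axes of a single Transpose equivalent to applying
    `first_axes` and then `second_axes` (hypothetical helper)."""
    # Output axis j of the second transpose reads axis second_axes[j] of the
    # intermediate tensor, which itself reads axis first_axes[second_axes[j]]
    # of the original input.
    return tuple(first_axes[j] for j in second_axes)


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
a = (0, 2, 3, 1)  # e.g. NCHW -> NHWC
b = (0, 3, 1, 2)  # e.g. NHWC -> NCHW
fused = fuse_two_transposes(a, b)
assert np.array_equal(np.transpose(np.transpose(x, a), b), np.transpose(x, fused))
print(fused)  # (0, 1, 2, 3): the pair cancels and can be removed
```

When the fused permutation is the identity, as here, both Transpose nodes can be dropped entirely.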
- mrt.tfm_pass.rewrite(symbol, params)
Customized graph-level topo pass definition.
Equivalent graph transformation. Rewrite the current symbol into equivalent symbols that are feasible for quantization.
- mrt.tfm_pass.quantize(symbol, params, th_dict, precs, scales, op_input_precs, restore_names, shift_bits, softmax_lambd)
Customized graph-level topo pass definition.
MRT quantization process function.
The original graph can be denoted as follows:
Case 1. Enable Restore
    X1    X2   ...   Xn
    xs1   xs2  ...   xsn
      \    |        /
       \   |       /
        \  |      /
        current_op
where ‘xs1’, ‘xs2’, etc. respectively stand for the scales of the inputs.
The original graph will be restored into:
    X1    X2   ...   Xn
    xs1   xs2  ...   xsn
     |     |          |
     |     |          |
    RX1   RX2  ...   RXn
     1     1          1
      \    |         /
       \   |        /
        \  |       /
        current_op
            1
            |
            |
           ...
where ‘RX1’, ‘RX2’, … stand for the restored symbols with unit scale.
Case 2. Other Cases
The original graph will be quantized into:
    X1    X2   ...   Xn
    xs1   xs2  ...   xsn
     |     |          |
     |     |          |
    QX1   QX2  ...   QXn
      \    |         /
       \   |        /
        \  |       /
    [equivalent quantization subgraph]
            |
            |
           ...
In either case, if ‘current_op’ is one of the output symbols, it will be quantized again with respect to the specified output precision.
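For intuition about the input scales ‘xs’ and the restore step, here is a minimal sketch of generic symmetric linear quantization, assuming a scale of the form (2^(p-1) - 1) / threshold for precision p. It is written with NumPy, it is not MRT's actual quantization subgraph, and the function names `quantize` and `restore` are hypothetical.

```python
import numpy as np


def quantize(x, threshold, prec):
    """Symmetric linear quantization of x into signed `prec`-bit integers.

    Returns the integer tensor q and the scale, so that x ~= q / scale.
    """
    qmax = 2 ** (prec - 1) - 1
    scale = qmax / threshold
    q = np.clip(np.round(x * scale), -qmax, qmax).astype(np.int64)
    return q, scale


def restore(q, scale):
    """Divide out the scale, giving a tensor with unit scale (cf. RX1, RX2, ...)."""
    return q / scale


x = np.array([0.3, -0.9, 0.05, 1.7])
q, scale = quantize(x, threshold=1.0, prec=8)
rx = restore(q, scale)
# Values inside [-threshold, threshold] come back within half a quantization
# step; 1.7 lies outside the threshold and is clipped to 127.
```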
- mrt.tfm_pass.prepare_for_compile(symbol, params)
Customized graph-level topo pass definition.
Equivalent graph transformation. Convert the quantized symbols into ones that are feasible for compilation.
- mrt.tfm_pass.to_cvm(symbol, params)
Customized graph-level topo pass definition.
Cast the symbol attributes to CVM symbol attributes, and convert MxNet symbols into CVM symbols using the cvm-runtime libraries.
- mrt.tfm_pass.fuse_multiple_inputs(sym, params)
Customized graph-level topo pass definition.
The original graph has multiple inputs.
    Inp1   Inp2   ...   Inpn
     |      |            |
     |      |            |
    ...    ...    ...   ...
where ‘Inp1’, ‘Inp2’, and so on stand for the input symbols.
The original graph will be transformed into:
                  data_sum
             /       |        \
            /        |         \
           /         |          \
          /          |           \
         /           |            \
        /            |             \
  slice_axis     slice_axis    ...  slice_axis
  axis: 0        axis: 0            axis: 0
  begin: 0       begin: end1        begin: end(n-1)
  end: pd1       end: end1+pd2      end: end(n-1)+pdn
      |              |                  |
      |              |                  |
   reshape        reshape            reshape
  shape: shp1    shape: shp2        shape: shpn
      |              |                  |
      |              |                  |
     Inp1           Inp2     ...       Inpn
      |              |                  |
      |              |                  |
     ...            ...      ...       ...

where ‘data_sum’ is the concatenation of all the inputs, each flattened to one dimension; ‘pd1’, ‘pd2’, …, ‘pdn’ are the products of the inferred shapes of the inputs (i.e. their element counts); ‘end1’, …, ‘end(n-1)’ are the running end offsets of the preceding slices (so end1 = pd1); and ‘shp1’, ‘shp2’, …, ‘shpn’ are the inferred shapes of the inputs.
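The slice offsets in the transformed graph can be checked with a small NumPy sketch: each input is flattened into ‘data_sum’, and slicing `[end(k-1), end(k-1) + pd_k)` followed by a reshape to ‘shp_k’ recovers it. The helpers `fuse_inputs` and `split_inputs` are hypothetical illustrations working on arrays, not MRT functions working on symbols.

```python
import numpy as np


def fuse_inputs(inputs):
    """Flatten and concatenate all inputs into one 1-D `data_sum` array.

    Also returns, per input, the (begin, end, shape) needed to recover it.
    """
    specs, begin = [], 0
    for arr in inputs:
        end = begin + arr.size  # arr.size is the shape product pd_k
        specs.append((begin, end, arr.shape))
        begin = end
    data_sum = np.concatenate([arr.reshape(-1) for arr in inputs])
    return data_sum, specs


def split_inputs(data_sum, specs):
    """Mirror of the slice_axis(axis=0) + reshape nodes in the diagram."""
    return [data_sum[begin:end].reshape(shape) for begin, end, shape in specs]


inputs = [np.ones((2, 3)), np.arange(4.0).reshape(2, 2), np.zeros(5)]
data_sum, specs = fuse_inputs(inputs)
recovered = split_inputs(data_sum, specs)
```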
- mrt.tfm_pass.name_duplicate_check(symbol, params)
Check whether duplicate names exist in the graph.
- mrt.tfm_pass.params_unique(symbol, params)
Remove duplicate keys from the params dict.
- mrt.tfm_pass.input_name_replace(symbol, params)
Customized graph-level topo pass definition.
Rename the single input to ‘data’.
- mrt.tfm_pass.fuse_constant(symbol, params)
Customized graph-level topo pass definition.
Fix the shape of constant operators with respect to the inferred shapes.
- mrt.tfm_pass.attach_input_shape(symbol, params, input_shapes)
Customized graph-level topo pass definition.
Attach the given input shapes to the graph.
- Parameters
symbol (mxnet.symbol) – The graph symbols.
params (dict) – The graph parameters dict.
input_shapes (dict) – The name-shape map.
- Returns
ret – The symbol-param tuple after attaching the input shapes.
- Return type
tuple
- mrt.tfm_pass.infer_shape(symbol, params, input_shape=None)
Customized graph-level topo pass definition.
Collect the inferred shapes of the symbols in the graph.
- Parameters
symbol (mxnet.symbol) – The graph symbols.
params (dict) – The graph parameters dict.
input_shape (tuple) – The input shape of the data.
- Returns
ret – The name-shape map.
- Return type
dict
- mrt.tfm_pass._collect_attribute(op, **kwargs)
Collect the attribute value from the specified operator.
- Parameters
op (mxnet.symbol) – The input operator.
- Returns
ret – The output operator.
- Return type
mxnet.symbol
- mrt.tfm_pass.collect_op_names(symbol, params)
Customized graph-level topo pass definition.
Collect the set of operator types that appear in the graph.
- Parameters
symbol (mxnet.symbol) – The graph symbols.
params (dict) – The graph parameters dict.
- Returns
ret – The opname set.
- Return type
set
- mrt.tfm_pass.fuse_multiple_outputs(symbol, params)
Customized symbol-level topo pass definition.
Symbol-level multiple-outputs-fusion pass.
          X
          |
          |
     SliceChannel
      /    |    \
     /     |     \
    /      |      \
   /       |       \
  A1      A2  ...  An

The original graph will be transformed into:

                    X
             /      |       \
            /       |        \
           /        |         \
          /         |          \
         /          |           \
        /           |            \
  slice_axis     slice_axis   ...  slice_axis
  axis: ich      axis: ich         axis: ich
  begin: 0       begin: stride     begin: (n-1)*stride
  end: stride    end: 2*stride     end: n*stride
      |              |                 |
     A1             A2        ...     An

where ‘ich’ is the attribute ‘axis’ along which to split, and
\[stride = dim / num\_outputs\]
where ‘dim’ = ‘xshape[ich]’ is the size of ‘X’ along axis ‘ich’, ‘xshape’ is the inferred shape of ‘X’, and ‘num_outputs’ (= n) is the number of splits.
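The equivalence can be sanity-checked with NumPy, using `np.split` as a stand-in for SliceChannel with evenly divisible splits. This is an assumption for illustration: the sketch does not handle `squeeze_axis` or uneven splits, and `slice_channel_as_slices` is a hypothetical helper.

```python
import numpy as np


def slice_channel_as_slices(x, axis, num_outputs):
    """Emulate SliceChannel with explicit slices along `axis`."""
    dim = x.shape[axis]
    assert dim % num_outputs == 0, "dim must be divisible by num_outputs"
    stride = dim // num_outputs
    outs = []
    for k in range(num_outputs):
        index = [slice(None)] * x.ndim
        # begin: k * stride, end: (k + 1) * stride, as in the diagram
        index[axis] = slice(k * stride, (k + 1) * stride)
        outs.append(x[tuple(index)])
    return outs


x = np.arange(2 * 6 * 3).reshape(2, 6, 3)
parts = slice_channel_as_slices(x, axis=1, num_outputs=3)
assert all(np.array_equal(p, q) for p, q in zip(parts, np.split(x, 3, axis=1)))
```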
- mrt.tfm_pass._get_opt(out, lambd)
Compute the calibration threshold (‘opt’) of the given output.
\[mean\_v = \mathrm{mean}(out)\]
\[n = \mathrm{product}(\mathrm{shape}(out))\]
\[sqrt\_n = \sqrt{n}\]
\[std = \|out - mean\_v\|_2 / sqrt\_n\]
\[alpha = |mean\_v| + lambd \cdot std\]
\[absmax = \max_i |out_i|\]
Case 1. alpha < 0.95 * absmax
opt = alpha
Case 2. Other cases
opt = absmax
- Parameters
out (nd.NDArray) – The graph level output.
lambd (float) – Hyperparameter that weights the standard deviation when computing alpha.
- Returns
opt – The opt value.
- Return type
float
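A NumPy sketch of the formulas above, with `np.ndarray` standing in for `nd.NDArray`. It assumes ‘absmax’ means the largest absolute entry of ‘out’, as its name suggests; `get_opt` is a hypothetical stand-in for the private `_get_opt`.

```python
import math

import numpy as np


def get_opt(out, lambd):
    """Return alpha if it is clearly below absmax, otherwise absmax."""
    out = np.asarray(out, dtype=np.float64)
    absmax = float(np.abs(out).max())  # assumption: largest absolute entry
    mean_v = float(out.mean())
    sqrt_n = math.sqrt(out.size)  # size == product(shape(out))
    std = float(np.linalg.norm(out - mean_v)) / sqrt_n
    alpha = abs(mean_v) + lambd * std
    # Case 1: clip to alpha; Case 2: fall back to absmax.
    return alpha if alpha < 0.95 * absmax else absmax
```

For `out = [1, -1, 1, -1]` the mean is 0 and std is 1, so `lambd=0.5` yields 0.5 (Case 1) while `lambd=2` yields alpha = 2 and falls back to absmax = 1.0 (Case 2).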
- mrt.tfm_pass.sym_calibrate(symbol, params, data, **kwargs)
Customized graph-level topo pass definition.
Calibrate the MRT model with the given input data and compute the threshold of each symbol.
- Parameters
symbol (mxnet.symbol) – The graph symbols.
params (dict) – The graph parameters dict.
data (nd.NDArray) – The input data.
- Returns
ret – The threshold dict after calibration.
- Return type
dict
- mrt.tfm_pass.convert_params_dtype(params, src_dtypes=['float32', 'float64'], dest_dtype='float64')
Convert the parameters of the source data type(s) into the target data type.
- Parameters
params (nd.NDArray) – The input data to be converted.
src_dtypes (str or list of str) – The source data type(s).
dest_dtype (str) – The target data type.
- Returns
ret – The precision of the input.
- Return type
int
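Assuming `params` is a name-to-array dict, as elsewhere in this module, the conversion could look like the sketch below. It uses NumPy arrays instead of MXNet NDArrays and returns a new dict, which may differ from what MRT actually returns; `convert_params_dtype_sketch` is a hypothetical name.

```python
import numpy as np


def convert_params_dtype_sketch(params, src_dtypes=("float32", "float64"),
                                dest_dtype="float64"):
    """Cast every array whose dtype is in `src_dtypes` to `dest_dtype`."""
    if isinstance(src_dtypes, str):
        src_dtypes = [src_dtypes]  # accept a single dtype name
    return {
        name: arr.astype(dest_dtype) if arr.dtype.name in src_dtypes else arr
        for name, arr in params.items()
    }


params = {"w": np.ones(3, dtype=np.float32), "idx": np.arange(3, dtype=np.int32)}
converted = convert_params_dtype_sketch(params)
```

The default is a tuple rather than a list so that the default argument cannot be mutated between calls; arrays whose dtype is not listed are passed through unchanged.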
mrt.sym_utils
Collection of MRT pass tool functions: topological-sequence helper functions for MxNet graph parsing, which simplify MRT graph traversal.
Only the key utility functions are documented here.
- mrt.sym_utils.is_op(sym, params)
Judge whether the input symbol is an operator.
- Parameters
sym (mxnet.symbol) – The input symbol.
params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.
- Returns
ret – Return False if the input symbol is not an operator, else True.
- Return type
bool
- mrt.sym_utils.is_var(sym, params)
Judge whether the input symbol is a variable.
- Parameters
sym (mxnet.symbol) – The input symbol.
params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.
- Returns
ret – Return False if the input symbol is not a variable, else True.
- Return type
bool
- mrt.sym_utils.is_params(sym, params)
Judge whether the input symbol is a parameter.
- Parameters
sym (mxnet.symbol) – The input symbol.
params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.
- Returns
ret – Return False if the input symbol is not a parameter, else True.
- Return type
bool
- mrt.sym_utils.is_inputs(sym, params)
Judge whether the input symbol is an input.
- Parameters
sym (mxnet.symbol) – The input symbol.
params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.
- Returns
ret – Return False if the input symbol is not an input, else True.
- Return type
bool
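One way to realize these four predicates, assumed here for illustration, is: a symbol is an operator unless its op is `'null'`; a variable is any non-operator; a variable whose name appears in `params` is a parameter; and any other variable is an input. The `Node` dataclass is a hypothetical stand-in for `mxnet.symbol`.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for an mxnet.symbol node."""
    name: str
    op_name: str  # 'null' marks a variable in MXNet's JSON graph format


def is_op(sym, params):
    return sym.op_name != "null"


def is_var(sym, params):
    return sym.op_name == "null"


def is_params(sym, params):
    # A variable with a stored value is a parameter (weight, bias, ...).
    return is_var(sym, params) and sym.name in params


def is_inputs(sym, params):
    # A variable without a stored value must be fed at run time.
    return is_var(sym, params) and sym.name not in params
```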
- mrt.sym_utils.nd_array(source_array, ctx=None, dtype='float64')
Convert the source array into MxNet NDArray.
- Parameters
source_array (tuple or list) – The input array to be converted.
ctx (mxnet.context) – The context on which to store the converted array.
dtype (str) – The entry data type of the converted array.
- Returns
ret – The converted result.
- Return type
mxnet.NDArray
- mrt.sym_utils.nd_arange(*args, **kwargs)
MRT wrapper function for mxnet.nd.arange.
- mrt.sym_utils.nd_full(*args, **kwargs)
MRT wrapper function for mxnet.nd.full.
- mrt.sym_utils.nd_zeros(*args, **kwargs)
MRT wrapper function for mxnet.nd.zeros.
- mrt.sym_utils.nd_ones(*args, **kwargs)
MRT wrapper function for mxnet.nd.ones.
- mrt.sym_utils.check_graph(symbol, params, logger=logging)
Check whether duplicate symbol names exist in the graph.
Also check the input names and parameter names, and remove unused parameters.
- Parameters
symbol (mxnet.symbol) – The input symbol.
params (dict) – MxNet symbol name (str) maps to mxnet.NDArray.
- Returns
ret – The validated symbol and params.
- Return type
tuple
- mrt.sym_utils.get_attr(attr, name, default=<object object>)
Get the value of the specified attribute from the attribute dict.
- Parameters
attr (dict) – The attribute dict of the symbol.
name (str) – The attribute name.
- Returns
ret – The attribute value.
- Return type
str, int, float, etc.
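Assuming `attr` is the string-valued attribute dict of a symbol (for example, what `mxnet.symbol.Symbol.list_attr()` returns), the lookup with a sentinel default could be sketched as follows. `get_attr_sketch`, `_NO_DEFAULT`, and the use of `ast.literal_eval` for parsing are illustrative assumptions, not MRT's implementation.

```python
import ast

_NO_DEFAULT = object()  # sentinel: distinguishes "no default" from default=None


def get_attr_sketch(attr, name, default=_NO_DEFAULT):
    """Return attr[name] parsed into a Python value, or `default` if absent."""
    if name not in attr:
        if default is _NO_DEFAULT:
            raise KeyError("attribute %r is missing and has no default" % name)
        return default
    raw = attr[name]
    try:
        # "3" -> 3, "(1, 1)" -> (1, 1), "True" -> True, "0.5" -> 0.5
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return raw  # plain strings such as "relu" stay as they are
```

The sentinel object is what a `default=<object object>` signature typically indicates: it lets `None` be passed as a legitimate default.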
- mrt.sym_utils.get_nd_op(op_name)
Get the MxNet NDArray operator handle of the specified op type.
Typical usage: Calibration, op-level output restore, etc.
- Parameters
op_name (str) – The input operator type name.
- Returns
ret – The MxNet NDArray operator handle.
- Return type
mxnet.NDArray
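Operator-handle lookup of this kind is typically a name-based `getattr` over one or more modules. The sketch shows the idea with NumPy modules standing in for `mxnet.ndarray`, plus a submodule as an assumed fallback; `get_op_handle` is a hypothetical name.

```python
import numpy as np


def get_op_handle(op_name, modules=(np, np.linalg)):
    """Return the first callable named `op_name` found in `modules`."""
    for mod in modules:
        op = getattr(mod, op_name, None)
        if callable(op):
            return op
    raise NotImplementedError("operator %s is not supported" % op_name)


add = get_op_handle("add")    # found in the numpy namespace itself
norm = get_op_handle("norm")  # not in numpy itself, found in numpy.linalg
```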
- mrt.sym_utils.get_mxnet_op(op_name)
Get the MxNet symbol operator handle of the specified op type.
- Parameters
op_name (str) – The input operator type name.
- Returns
ret – The MxNet symbol operator handle.
- Return type
mxnet.NDArray
- mrt.sym_utils.get_nnvm_op(op_name)
Get the CVM symbol operator handle of the specified op type.
- Parameters
op_name (str) – The input operator type name.
- Returns
ret – The CVM symbol operator handle.
- Return type
cvm.symbol
- mrt.sym_utils.sym_iter(sym)
Iterate over the symbol group and return the list of its member symbols.
- Parameters
sym (mxnet.symbol or cvm.symbol) – The input symbol.
- Returns
ret – list of CVM symbols or MxNet symbols.
- Return type
list
- mrt.sym_utils.nd_const(number, graph, params)
Create the MxNet variable symbol that represents the given constant.
- Parameters
number (float) – The input constant number to be converted.
graph (dict) – The symbol name maps to MxNet symbol.
params (dict) – The symbol name maps to MxNet NDArray.
- Returns
ret – The created MxNet variable symbol.
- Return type
mxnet.symbol
- mrt.sym_utils.topo_sort(symbol, logger=logging, with_deps=False)
Sort all symbols in the MxNet graph in topological order.
- Parameters
symbol (mxnet.symbol or cvm.symbol) – The input symbol.
with_deps (bool) – Whether to also return the op-level output dict, which maps each symbol name to the set of its output names.
- Returns
ret – The MxNet or CVM symbols in topological order; if with_deps is True, also the op-level output dict.
- Return type
tuple
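The same ordering can be sketched on a plain dict with Kahn's algorithm. The dict-based graph format and the name `topo_sort_sketch` are assumptions for illustration; the returned `deps` mirrors the documented op-level output dict (symbol name to the set of output names).

```python
from collections import deque


def topo_sort_sketch(inputs_of, with_deps=False):
    """Kahn's algorithm over a graph given as {node: [input nodes]}.

    Every node, including variables with no inputs, must be a key.
    """
    deps = {name: set() for name in inputs_of}
    # Count distinct inputs so that e.g. add(x, x) is handled correctly.
    indegree = {name: len(set(ins)) for name, ins in inputs_of.items()}
    for name, ins in inputs_of.items():
        for src in ins:
            deps[src].add(name)
    ready = deque(name for name, d in indegree.items() if d == 0)
    order = []
    while ready:
        name = ready.popleft()
        order.append(name)
        for out in sorted(deps[name]):  # sorted for a deterministic order
            indegree[out] -= 1
            if indegree[out] == 0:
                ready.append(out)
    if len(order) != len(inputs_of):
        raise ValueError("graph contains a cycle")
    return (order, deps) if with_deps else order


graph = {"data": [], "w": [], "conv": ["data", "w"],
         "relu": ["conv"], "add": ["relu", "conv"]}
order, deps = topo_sort_sketch(graph, with_deps=True)
```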
- mrt.sym_utils.get_entry_id(sym)
Get the entry id of the symbol with respect to its symbol group.
- Parameters
sym (mxnet.symbol or cvm.symbol) – The input symbol.
- Returns
ret – The entry id.
- Return type
int
- mrt.sym_utils.get_node(sym, graph)
Get the symbol in the provided graph that has the same name as the given symbol.
This assumes every graph node has a single output; multiple-output nodes are fused beforehand by the fuse_multiple_outputs sym_pass.
- Parameters
sym (mxnet.symbol or cvm.symbol) – The input symbol.
graph (dict) – The symbol name maps to MxNet symbol or CVM symbol.
- Returns
ret – The Mxnet symbol or CVM symbol.
- Return type
mxnet.symbol or cvm.symbol
- mrt.sym_utils.topo_visit(symbol, params, inputs_ext, callback, get_op=<function get_mxnet_op>, logger=logging, with_maps=False, **kwargs)
MRT topological graph traversal function.
- Parameters
symbol (mxnet.symbol or cvm.symbol) – The input symbol.
params (dict) – MxNet or CVM symbol name (str) maps to mxnet.NDArray or cvm.ndarray.
inputs_ext (dict) – Input info name maps to input info value.
callback (function) – op-level pass function.
get_op (function) – Operator acquisition function.
with_maps (bool) – Each visit may change the name of the corresponding symbol; this flag controls whether to also return the dict that maps old names to new names.
- Returns
ret – The MxNet or CVM symbol and the parameters dict; if with_maps is True, also the dict that maps each old symbol name to its new name.
- Return type
tuple
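Conceptually, the traversal visits nodes in topological order, rewires each node's inputs to the already-visited (possibly renamed) predecessors, applies `callback`, and records old-to-new names, which is what `with_maps` exposes. The tuple-based node format, `topo_visit_sketch`, and the example pass `rename_relu` are hypothetical; the real function operates on MXNet or CVM symbols.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical node format: (name, op_name, input_names).
Node = Tuple[str, str, List[str]]


def topo_visit_sketch(order: List[Node], callback: Callable[[Node], Node]):
    """Visit nodes (already in topological order) and apply `callback`.

    Returns the rewritten nodes and the old-name -> new-name map.
    """
    name_map: Dict[str, str] = {}
    visited: List[Node] = []
    for name, op_name, inputs in order:
        # Predecessors were visited earlier, so their new names are known.
        new_inputs = [name_map.get(i, i) for i in inputs]
        new_node = callback((name, op_name, new_inputs))
        name_map[name] = new_node[0]
        visited.append(new_node)
    return visited, name_map


def rename_relu(node: Node) -> Node:
    """Example op-level pass: tag every relu with a '_q' suffix."""
    name, op_name, inputs = node
    return (name + "_q", op_name, inputs) if op_name == "relu" else node


graph = [
    ("data", "null", []),
    ("conv", "Convolution", ["data"]),
    ("act", "relu", ["conv"]),
    ("out", "Flatten", ["act"]),
]
new_graph, name_map = topo_visit_sketch(graph, rename_relu)
```

Because consumers are rewired through `name_map`, a pass that renames a node never leaves a downstream node pointing at the old name.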
- mrt.sym_utils.topo_visit_transformer(symbol, params, callback, get_op=<function get_mxnet_op>, logger=logging, **kwargs)
MRT transformer topological graph traversal function.
- Parameters
symbol (mxnet.symbol or cvm.symbol) – The input symbol.
params (dict) – MxNet or CVM symbol name (str) maps to mxnet.NDArray or cvm.ndarray.
callback (function) – op-level pass function.
get_op (function) – Operator acquisition function.
- Returns
ret – The Mxnet symbol or CVM symbol, parameters dict.
- Return type
tuple