MRT Operator API
mrt.tfm_base
Customized Symbolic Pass Interfaces. Base passes with default operation settings. Collection of transformer management functions.
- class mrt.tfm_base.Transformer
Base transformer object.
- Every subclass inheriting from this class should be registered manually
using the helper function register_transformer. Each class function should then be either deliberately overridden or, if there is no point in redefining a duplicate function, annotated with the helper function register_pass so that the function defined in this base class is used.
- Subclasses should only implement functions defined in the base object,
and any helper function should be named with an underscore prefix.
- Please refer to the file tfm_ops.py for more examples of
operator transformers.
- calculate_ops(op, **kwargs)
Calculate the amount of computation for the operator.
Returns the output size by default.
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
Throws an exception by default.
- fuse_transpose(op, **kwargs)
Equivalent graph transposition.
When at least one of two adjacent ops is a Transpose, the other op may be either swappable or fusable with the Transpose.
Does nothing by default.
- op_name = 'none'
Transformer Operator Name
- Each transformer is associated with an operator defined
in mxnet, and this variable holds the type name of the mxnet symbol.
- Note that the base transformer should not be instantiated,
since it is only an abstract aggregation of graph passes; it is named none by default.
- prepare_for_compile(op, **kwargs)
An equivalent graph transition may be needed before compilation, such as the dynamic shape fixation for MulScalar, DivScalar, ZerosLike and OnesLike that is only needed in quantization. Does nothing by default.
- quantize(op, **kwargs)
Main procedure for quantization.
Does nothing by default.
- rewrite(op, **kwargs)
- Operators may need to be rewritten into an equivalent graph that is
easier to quantize in later procedures.
Does nothing by default.
- validate(op, **kwargs)
- All operators should be validated before any other pass,
either correcting an invalid format or raising an assertion error to report an unsupported graph.
Does nothing by default.
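The registration rules above can be pictured with a small, self-contained registry. This is only a sketch under stated assumptions: `register`, `get_transformer` and `REGISTRY` are hypothetical helpers (not MRT's actual register_transformer or register_pass), ops are plain strings instead of mxnet symbols, and each default pass returns the op unchanged to mirror the do-nothing defaults.

```python
# Toy registry sketch; names are hypothetical, not MRT's register_transformer API.
REGISTRY = {}


def register(op_name):
    """Hypothetical decorator: record a transformer class under an mxnet op name."""
    def decorator(cls):
        cls.op_name = op_name
        REGISTRY[op_name] = cls
        return cls
    return decorator


def get_transformer(op_name):
    """Instantiate the transformer registered for `op_name`."""
    return REGISTRY[op_name]()


class Transformer:
    """Toy base object: every pass has a default behaviour."""
    op_name = "none"

    def validate(self, op, **kwargs):
        return op  # does nothing by default

    def rewrite(self, op, **kwargs):
        return op  # does nothing by default

    def quantize(self, op, **kwargs):
        return op  # does nothing by default

    def compile(self, op, **kwargs):
        raise NotImplementedError(f"compile is not defined for {self.op_name}")


@register("_mul_scalar")
class MulScalar(Transformer):
    def rewrite(self, op, **kwargs):
        # Toy rewrite: ops are strings here, not mxnet symbols.
        return self._as_broadcast_mul(op, kwargs.get("scalar", 1))

    def _as_broadcast_mul(self, op, scalar):
        # Helper functions carry an underscore prefix.
        return f"broadcast_mul({op}, {scalar})"
```

Only `rewrite` is overridden in the toy subclass, and its helper carries an underscore prefix, following the conventions listed for subclasses.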
mrt.tfm_ops
Op-level realization of Model Representation Tool. Implementation of validation, equivalent transformation, quantization, transpose fusion, ops calculation, preparation and compilation.
Only crucial parts of the customized pass implementation are elaborated.
The supported MxNet operators are grouped as follows:
NN Operators
Transform Operators
Mathematic Operators
Arrange Operators
Broadcast Operators
Elemwise Operators
Scalar Operators
Reduce Operators
Custom Operators
Vision Operators
- class mrt.tfm_ops.Null
- quantize(op, **kwargs)
Customized quantize pass Introduction.
Transform the input data.
- class mrt.tfm_ops.Transpose
- fuse_transpose(op, **kwargs)
Customized fuse_transpose pass Introduction.
For a consecutive transpose sequence such as:
\[Y = \text{Transpose}(X, axes_1), \quad Z = \text{Transpose}(Y, axes_2)\]
apply an equivalent transformation to the two adjacent transpose ops:
\[\begin{aligned} Z &= \text{Transpose}(Y, axes_2) \\ &= \text{Transpose}(\text{Transpose}(X, axes_1), axes_2) \\ &= \text{Transpose}(X, axes_3), \quad \text{where } axes_3(j) = axes_1(axes_2(j)) \end{aligned}\]
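The axes composition can be checked without mxnet. In this sketch, `transpose_index` maps an output index of Transpose(X, axes) back to the input index it reads (our own helper, assuming mxnet's convention that output axis j takes input axis axes[j]); composing the two mappings agrees with a single transpose by axes_3 at every index.

```python
import itertools


def transpose_shape(shape, axes):
    """Shape of Transpose(X, axes): output axis j takes input axis axes[j]."""
    return tuple(shape[a] for a in axes)


def fuse_axes(axes1, axes2):
    """axes_3(j) = axes_1(axes_2(j)) for Transpose(Transpose(X, axes_1), axes_2)."""
    return tuple(axes1[a] for a in axes2)


def transpose_index(idx, axes):
    """Input index read by output index `idx` of Transpose(X, axes)."""
    src = [0] * len(axes)
    for j, a in enumerate(axes):
        src[a] = idx[j]
    return tuple(src)


def all_indices(shape):
    """Every index tuple of a tensor with the given shape."""
    return itertools.product(*(range(n) for n in shape))
```

For example, fusing axes_1 = (1, 2, 0) with axes_2 = (0, 2, 1) gives axes_3 = (1, 0, 2).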
- class mrt.tfm_ops.Relu
- fuse_transpose(op, **kwargs)
Customized fuse_transpose pass Introduction.
See mrt.tfm_ops.reverse_transpose for reference.
- class mrt.tfm_ops.LeakyReLU
- fuse_transpose(op, **kwargs)
Customized fuse_transpose pass Introduction.
See mrt.tfm_ops.reverse_transpose for reference.
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
LeakyReLU can be equivalently transformed into operators supported by cvm:
\[LeakyReLU(X) = relu(X) - slope*relu(-X)\]
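A quick scalar check of the LeakyReLU identity, assuming act_type ‘leaky’: for x > 0 the second term vanishes, and for x < 0 the first term vanishes and -slope*relu(-x) = slope*x. Plain Python floats stand in for tensors, and the helper names are illustrative.

```python
def relu(x):
    return max(x, 0.0)


def leaky_relu(x, slope):
    """Reference LeakyReLU with act_type 'leaky'."""
    return x if x > 0 else slope * x


def leaky_relu_rewritten(x, slope):
    """LeakyReLU(X) = relu(X) - slope * relu(-X), built only from relu, mul and sub."""
    return relu(x) - slope * relu(-x)
```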
- validate(op, **kwargs)
Customized validate pass Introduction.
The activation function only supports act_type ‘leaky’.
- class mrt.tfm_ops.MulScalar
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Transform into broadcast_mul.
- class mrt.tfm_ops.DivScalar
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Transform into broadcast_mul with the reciprocal of the scalar.
- class mrt.tfm_ops.Activation
- rewrite(op, **kwargs)
Equivalent rewrite of the operator. Only applies when the attribute act_type equals relu or sigmoid, in which case the op can be directly transformed into the corresponding operator.
- validate(op, **kwargs)
Customized validate pass Introduction.
The activation function only supports relu.
- class mrt.tfm_ops.Convolution
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_xwb for reference.
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
An input node with layout NCW is equivalently rewritten into layout NCHW. The parameters of the attached dimension are set to their defaults.
- class mrt.tfm_ops.Pad
- compile(op, **kwargs)
Customized compile pass Introduction.
Only the constant padding type is supported.
- class mrt.tfm_ops.ExpandDims
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'expand_dims'
Transformer Operator Name
- class mrt.tfm_ops.Embedding
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'Embedding'
Transformer Operator Name
- quantize(op, **kwargs)
Customized quantize pass Introduction.
The input array (the indices) of the embedding operator should be quantized with a scale of 1.
- class mrt.tfm_ops.Repeat
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'repeat'
Transformer Operator Name
- class mrt.tfm_ops.BoxNms
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = '_contrib_box_nms'
Transformer Operator Name
- class mrt.tfm_ops.SliceLike
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'slice_like'
Transformer Operator Name
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_scale for reference.
- class mrt.tfm_ops.SliceAxis
- op_name = 'slice_axis'
Transformer Operator Name
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
‘SliceAxis’ can be equivalently transformed into ‘Slice’, which is supported by cvm.
- validate(op, **kwargs)
All operators should be validated before any other pass.
- class mrt.tfm_ops.SliceChannel
- op_name = 'SliceChannel'
Transformer Operator Name
- class mrt.tfm_ops.UpSampling
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'UpSampling'
Transformer Operator Name
- validate(op, **kwargs)
All operators should be validated before any other pass.
- class mrt.tfm_ops.FullyConnected
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_xwb for reference.
- reduce(op, **kwargs)
Dimension reduction function covering both flatten cases.
Denote the input as X and the transformed operator as Y. If flatten is true, only one reduction of the high-dimensional input to 2 dimensions is needed:
\[RX = reshape(X), \quad Y = FullyConnected(RX)\]
If flatten is false, the input is first reduced to 2 dimensions. After the FullyConnected op, the output is reshaped to the correct output shape:
\[RX = reshape(X), \quad out = FullyConnected(RX), \quad Y = reshape(out)\]
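The shape bookkeeping behind the two flatten cases can be sketched as below. It assumes mxnet's FullyConnected semantics, where flatten=True collapses all axes after the first and flatten=False applies the layer to the last axis only; the exact reshape targets used by MRT are an assumption here, and `fc_reduce_shapes` is a hypothetical helper.

```python
from math import prod


def fc_reduce_shapes(xshp, num_hidden, flatten):
    """Return (RX shape, FullyConnected output shape, final Y shape)."""
    if flatten:
        # Collapse every axis after the batch axis: (N, d1, ..., dk) -> (N, d1*...*dk).
        rx = (xshp[0], prod(xshp[1:]))
        out = (rx[0], num_hidden)
        return rx, out, out  # no reshape needed afterwards
    # flatten=False: the layer acts on the last axis only.
    rx = (prod(xshp[:-1]), xshp[-1])
    out = (rx[0], num_hidden)
    y = tuple(xshp[:-1]) + (num_hidden,)  # restore the leading axes
    return rx, out, y
```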
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Use matrix decomposition to avoid overflow:
\[Y = B + X W^T = B + X_1 W_1^T + X_2 W_2^T + \cdots\]
where
\[W_i\text{.shape} = (numHidden, step), \quad W = [W_1, W_2, \ldots]\]
\[X_i\text{.shape} = (batchSize, step), \quad X = [X_1, X_2, \ldots]\]
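The decomposition is exact because a dot product splits into partial sums over column chunks. The sketch below (hypothetical helpers on nested lists of ints) accumulates B + X_1 W_1^T + X_2 W_2^T + ... and matches the direct product; each partial term only contracts `step` elements, which is what bounds the size of intermediate results. The same chunking along the contracted dimension underlies the BatchDot rewrite.

```python
def fc_direct(x, w, b):
    """Y = B + X W^T with X: (batch, k), W: (hidden, k), B: (hidden,)."""
    return [[bj + sum(xv * wv for xv, wv in zip(row, wrow)) for wrow, bj in zip(w, b)]
            for row in x]


def fc_decomposed(x, w, b, step):
    """Same result, accumulated as B + X_1 W_1^T + X_2 W_2^T + ... over column chunks."""
    k = len(x[0])
    y = [list(b) for _ in x]
    for start in range(0, k, step):
        stop = min(start + step, k)
        for r, row in enumerate(x):
            for c, wrow in enumerate(w):
                # Partial product X_i W_i^T contracts at most `step` elements.
                y[r][c] += sum(row[t] * wrow[t] for t in range(start, stop))
    return y
```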
- class mrt.tfm_ops.Sigmoid
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_table for reference.
- class mrt.tfm_ops.Exp
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_table for reference.
- class mrt.tfm_ops.Softmax
- quantize(op, **kwargs)
Customized quantize pass Introduction.
Step 1. Requant Input
\[Xq, xprec, xs = requant\_operator(X, iprec, oscale)\]
Step 2. Calculate Norm Value
First, calculate the bias with respect to the unscaled input (denoted as ‘alpha’) as:
\[alpha = int(lambd * xs)\]
where ‘lambd’ stands for a hyperparameter configured by the user.
Then, calculate the offset value with respect to each axis, and clip it:
\[max\_axis = max(Xq, axis)\]\[offset = broadcast\_sub(max\_axis, alpha)\]\[offset\_c = clip(offset, xprec)\]
Next, calculate the norm and clip it:
\[norm = relu(Xq - offset\_c)\]\[norm = clip(norm, xprec)\]
Step 3. Create Lookup Table
\[dim = alpha + 1\]\[data = range(0, dim)\]\[table = exp(data / xs)\]\[table\_prec = get\_bit(exp(||table||_2))\]
For reference of ‘get_bit’, see mrt.tfm_utils.get_bit.
\[table\_c = clip(table, min=0, max=get\_range(table\_prec))\]
For reference of ‘get_range’, see mrt.tfm_utils.get_range.
\[weight = reshape(round(table\_c), (dim, 1))\]
Step 4. Get Lookup Value
The cvm customized operator ‘cvm_lut’ is adopted.
\[lut = cvm\_lut(norm, weight, dim)\]
Step 5. Get Output
\[sum\_lut = sum(lut, axis)\]\[oprec = min(15, 31 - table\_prec)\]\[oscale = get\_range(oprec)\]\[prob = lut * oscale\]\[half\_lut = realize(sum\_lut, 1, 31)\]
For reference of ‘realize’, see mrt.tfm_utils.realize.
\[prob\_b = prob + half\_lut\]\[prob = prob\_b / sum\_lut\]
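A pure-Python sketch of the integer softmax idea follows, for a single 1-D row. It is simplified under stated assumptions: the input is already an int list with scale `xs` (Step 1 is skipped), `get_range(prec) = 2**(prec - 1) - 1` is an assumed convention, `table_prec` is taken from the largest table entry rather than from get_bit, and `sum_lut >> 1` stands in for realize(sum_lut, 1, 31), assumed to be a 1-bit right shift used for rounding. Entries more than ‘lambd’ below the row maximum all map to table[0].

```python
import math


def get_range(prec):
    """Assumed convention: largest magnitude of a signed `prec`-bit integer."""
    return 2 ** (prec - 1) - 1


def lut_softmax(xq, xs, lambd):
    """Integer-only softmax over a 1-D list of ints `xq` quantized with scale `xs`."""
    alpha = int(lambd * xs)                   # Step 2: table index range
    offset = max(xq) - alpha                  # the row maximum maps to index alpha
    norm = [max(v - offset, 0) for v in xq]   # relu(Xq - offset), lies in [0, alpha]
    # Step 3: table[k] = exp(k / xs) for k in [0, alpha], rounded to integers.
    table = [round(math.exp(k / xs)) for k in range(alpha + 1)]
    table_prec = max(table).bit_length() + 1  # signed bits covering the table (assumption)
    # Step 4: lookup, the role played by cvm_lut.
    lut = [table[n] for n in norm]
    # Step 5: rounded integer division by the sum.
    oprec = min(15, 31 - table_prec)
    oscale = get_range(oprec)
    sum_lut = sum(lut)
    half_lut = sum_lut >> 1                   # stands in for realize(sum_lut, 1, 31)
    probs = [(v * oscale + half_lut) // sum_lut for v in lut]
    return probs, oscale
```

For the row [10, 20, 30] with scale 10 (real values 1, 2, 3) and lambd = 10, the integer outputs divided by `oscale` agree with a floating-point softmax to within 1e-4.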
- class mrt.tfm_ops.Pooling
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Case 1. ‘pool_type’ is ‘avg’ and ‘global_pool’ is True
\[scale\_sym = 1 / (xshp[2] * xshp[3])\]
where ‘xshp’ is the inferred shape of the input ‘X’.
\[op\_s = sum(X, axis=(2, 3))\]\[op = op\_s * scale\_sym\]
Case 2. ‘pool_type’ is ‘avg’ and ‘global_pool’ is False
The ‘Pooling’ operator can be converted into a ‘Convolution’. First, set up the attributes:
conv_attr = { 'no_bias': 'True', 'dilate': '(1, 1)', 'kernel': kernel, 'stride': stride, 'pad': pad, 'layout': 'NCHW', 'num_filter': xshp[1], 'num_group': xshp[1], }
where ‘kernel’ is the pooling kernel size, ‘stride’ is the pooling stride, and ‘pad’ is the pooling padding. Then convolve with a constant weight:
\[W = full(shape=wshp, val=1/product(kernel))\]\[op = Convolution(X, W, conv\_attr)\]
where ‘wshp’ = (xshp[1], 1, kernel[0], kernel[1]) is the depthwise weight shape implied by ‘num_filter’ = ‘num_group’ = xshp[1].
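Per channel, average pooling is a convolution whose kernel is the constant 1/product(kernel); grouping with num_group = num_filter = xshp[1] applies one such filter to each channel independently. The single-channel sketch below (no padding, hypothetical helper names, nested lists in place of NDArrays) checks both cases. With padding, the convolution also counts padded zeros in the divisor, which matches average pooling only when ‘count_include_pad’ is True, consistent with the validate pass.

```python
def avg_pool2d(x, kernel, stride):
    """Reference average pooling on one channel (2-D list), no padding."""
    kh, kw = kernel
    sh, sw = stride
    return [
        [sum(x[i + a][j + b] for a in range(kh) for b in range(kw)) / (kh * kw)
         for j in range(0, len(x[0]) - kw + 1, sw)]
        for i in range(0, len(x) - kh + 1, sh)
    ]


def conv2d(x, weight, stride):
    """Single-channel cross-correlation, which is what each depthwise group computes."""
    kh, kw = len(weight), len(weight[0])
    sh, sw = stride
    return [
        [sum(x[i + a][j + b] * weight[a][b] for a in range(kh) for b in range(kw))
         for j in range(0, len(x[0]) - kw + 1, sw)]
        for i in range(0, len(x) - kh + 1, sh)
    ]


def avg_pool_as_conv(x, kernel, stride):
    """Case 2: a constant weight 1/product(kernel) turns convolution into averaging."""
    val = 1.0 / (kernel[0] * kernel[1])
    weight = [[val] * kernel[1] for _ in range(kernel[0])]
    return conv2d(x, weight, stride)


def global_avg_pool(x):
    """Case 1: sum over H and W, then multiply by 1 / (H * W)."""
    return sum(sum(row) for row in x) * (1.0 / (len(x) * len(x[0])))
```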
- validate(op, **kwargs)
Customized validate pass Introduction.
Only the ‘NCHW’ ‘layout’ is supported.
Only the ‘max’ and ‘avg’ values of ‘pool_type’ are supported. If ‘pool_type’ is ‘avg’ and ‘pooling_convention’ is ‘full’, then ‘global_pool’ must be True. If ‘pool_type’ is ‘avg’ and ‘pooling_convention’ is not ‘full’, then ‘pooling_convention’ must be ‘valid’ and ‘global_pool’ must be True.
‘count_include_pad’ must be True.
- class mrt.tfm_ops.BroadcastMul
- prepare_for_compile(op, **kwargs)
Customized prepare_for_compile pass Introduction.
If either input equals zero, the op can be merged into a constant zero.
- quantize(op, **kwargs)
Customized quantize pass Introduction.
\[Xq, xprec, xs = requant(X, oprec)\]\[Bq, bprec, bs = requant(B, oprec)\]
where ‘oprec’ stands for the default quantization precision for ‘BroadcastMul’. See mrt.tfm_utils.requant for reference.
\[op = Xq * Bq\]
The inferred precision equals ‘xprec’ plus ‘bprec’.
- class mrt.tfm_ops.BroadcastAdd
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_scale for reference.
- class mrt.tfm_ops.BroadcastSub
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_scale for reference.
- class mrt.tfm_ops.BroadcastTo
- op_name = 'broadcast_to'
Transformer Operator Name
- class mrt.tfm_ops.BroadcastGreater
- op_name = 'broadcast_greater'
Transformer Operator Name
- class mrt.tfm_ops.Concat
- fuse_transpose(op, **kwargs)
Customized fuse_transpose pass Introduction.
Suppose the inputs are all ‘Transpose’ ops about the same ‘axis’:
Concat(Transpose(cA, axis), Transpose(cB, axis), ..., Transpose(cC, axis))
Then the graph can be transformed into:
Transpose(Concat(cA, cB, ..., cC), axis)
where the ‘Transpose’ here is also about ‘axis’, and the concatenation dimension ‘dim’ becomes axis[dim].
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_scale for reference.
- class mrt.tfm_ops.Sum
- fuse_transpose(op, **kwargs)
Equivalent graph transposition.
- quantize(op, **kwargs)
Customized quantize pass Introduction.
\[Xq, xprec, xs = requant(X, oprec)\]
where ‘oprec’ stands for the default quantization precision for ‘Sum’. See mrt.tfm_utils.requant for reference.
\[k = int(product(ishp) / product(oshp))\]\[kprec = get\_bit\_cnt(k)\]
where ‘ishp’ and ‘oshp’ stand for the input shape and output shape, respectively. See mrt.tfm_utils.get_bit_cnt for reference.
The inferred precision equals ‘xprec’ plus ‘kprec’.
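The precision rules for BroadcastMul (xprec + bprec) and Sum (xprec + kprec) follow from worst-case magnitudes: a product of two values bounded by get_range(p) and get_range(q), or a sum of k values bounded by get_range(p), always fits in the stated number of bits. The sketch assumes `get_range(prec) = 2**(prec - 1) - 1` for signed integers and uses Python's int.bit_length() in place of get_bit_cnt; both are assumptions, not the tfm_utils implementations.

```python
def get_range(prec):
    """Assumed convention: largest magnitude of a signed `prec`-bit integer."""
    return 2 ** (prec - 1) - 1


def bits_needed(value):
    """Smallest signed precision whose range covers |value|."""
    prec = 1
    while get_range(prec) < abs(value):
        prec += 1
    return prec


def mul_prec(xprec, bprec):
    """BroadcastMul rule: inferred precision = xprec + bprec."""
    return xprec + bprec


def sum_prec(xprec, k):
    """Sum rule: xprec + kprec, with kprec = k.bit_length() standing in for get_bit_cnt."""
    return xprec + k.bit_length()
```

For instance, summing k = 100 values of 8-bit precision gives an inferred precision of 8 + 7 = 15 bits.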
- class mrt.tfm_ops.BatchNorm
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Case 1. Convolution Input
When the graph looks like this:
BatchNorm(Convolution(X, W, B), gamma, beta, data_mean, data_var)
the BatchNorm is folded into the Convolution:
\[sc = gamma / \sqrt{data\_var + eps}\]\[weight = W * sc.reshape(C, 1, 1, 1)\]\[bias = beta - sc * data\_mean + sc * B\]\[op = Convolution(X, weight, bias, conv\_attr)\]
where ‘C’ is the length of ‘sc’ (the number of output channels), ‘conv_attr’ is the attribute of the original ‘Convolution’, and the ‘sc * B’ term is dropped when the convolution has no bias.
Case 2. Other Cases
When the graph looks like this:
BatchNorm(X, gamma, beta, data_mean, data_var)
compute the broadcast shape:
rshp = [s if i == axis else 1 for i, s in enumerate(ishp)]
where ‘axis’ is the attribute of the operator and ‘ishp’ is the input shape of ‘X’. With ‘sc’ defined as in Case 1:
\[weight = reshape(sc, rshp)\]\[bias = reshape(beta - sc * data\_mean, rshp)\]\[op = X * weight + bias\]
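The Case 1 folding can be checked per output channel with a 1x1 convolution on a single input channel, where the convolution reduces to y = w*x + b. The sketch (plain floats, hypothetical helper names, inference-mode BatchNorm) shows why the folded bias picks up sc * B: BatchNorm rescales the convolution's own bias together with its output.

```python
import math


def batchnorm(y, gamma, beta, mean, var, eps):
    """Inference-mode BatchNorm for one channel."""
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_batchnorm(w, b, gamma, beta, mean, var, eps):
    """Fold BatchNorm into a per-channel weight and bias (Case 1)."""
    sc = gamma / math.sqrt(var + eps)
    return w * sc, beta - sc * mean + sc * b
```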
- class mrt.tfm_ops.Flatten
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'Flatten'
Transformer Operator Name
- class mrt.tfm_ops.Floor
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'floor'
Transformer Operator Name
- class mrt.tfm_ops.Ceil
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'ceil'
Transformer Operator Name
- class mrt.tfm_ops.Round
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'round'
Transformer Operator Name
- class mrt.tfm_ops.Fix
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'fix'
Transformer Operator Name
- class mrt.tfm_ops.Cast
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'Cast'
Transformer Operator Name
- class mrt.tfm_ops.Slice
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'slice'
Transformer Operator Name
- class mrt.tfm_ops.Reshape
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'Reshape'
Transformer Operator Name
- class mrt.tfm_ops.Custom
- validate(op, **kwargs)
Customized validate pass Introduction.
Only the op types ‘cvm_clip’, ‘cvm_left_shift’, ‘cvm_right_shift’ and ‘cvm_lut’ are supported.
- class mrt.tfm_ops.Clip
- fuse_transpose(op, **kwargs)
Customized fuse_transpose pass Introduction.
See mrt.tfm_ops.reverse_transpose for reference.
- quantize(op, **kwargs)
Customized quantize pass Introduction.
\[amin = int(a\_min * iscale)\]\[amax = int(a\_max * iscale)\]
where ‘a_min’ and ‘a_max’ are attributes of the op, and ‘iscale’ is the scale of the input.
\[op = clip(X, amin, amax)\]
- class mrt.tfm_ops.Minimum
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = '_minimum'
Transformer Operator Name
- class mrt.tfm_ops.Maximum
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = '_maximum'
Transformer Operator Name
- class mrt.tfm_ops.Max
- op_name = 'max'
Transformer Operator Name
- class mrt.tfm_ops.Min
- op_name = 'min'
Transformer Operator Name
- class mrt.tfm_ops.Argmax
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'argmax'
Transformer Operator Name
- class mrt.tfm_ops.Argmin
- compile(op, **kwargs)
Compile mxnet symbol into nnvm symbol.
- op_name = 'argmin'
Transformer Operator Name
- class mrt.tfm_ops.Abs
- op_name = 'abs'
Transformer Operator Name
- class mrt.tfm_ops.ElemwiseAdd
- fuse_transpose(op, **kwargs)
Customized fuse_transpose pass Introduction.
See mrt.tfm_ops._quantize_scale for reference.
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_scale for reference.
- class mrt.tfm_ops.ElemwiseSub
- fuse_transpose(op, **kwargs)
Customized fuse_transpose pass Introduction.
See mrt.tfm_ops._quantize_scale for reference.
- quantize(op, **kwargs)
Customized quantize pass Introduction.
See mrt.tfm_ops._quantize_scale for reference.
- class mrt.tfm_ops.ElemwiseMul
- rewrite(op, **kwargs)
Validate that the inferred shapes of lhs and rhs are the same, so that this op can be rewritten into broadcast_mul; the corresponding cvm op is optimized at compile time.
- class mrt.tfm_ops.Dropout
- fuse_transpose(op, **kwargs)
Customized fuse_transpose pass Introduction.
See mrt.tfm_ops.reverse_transpose for reference.
- class mrt.tfm_ops.Arange
- op_name = '_arange'
Transformer Operator Name
- class mrt.tfm_ops.Tile
- op_name = 'tile'
Transformer Operator Name
- class mrt.tfm_ops.Negative
- op_name = 'negative'
Transformer Operator Name
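- class mrt.tfm_ops.SwapAxis
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
\[ndims = len(ishp)\]
where ‘ishp’ is the inferred shape of the input.
\[new\_axis = range(ndims)\]
with the entries at positions ‘dim1’ and ‘dim2’ swapped, where ‘dim1’ and ‘dim2’ are attributes of the op. The op is then rewritten as:
\[op = transpose(X, axes=new\_axis)\]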
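The axis bookkeeping for SwapAxis can be sketched as below, assuming the rewrite target is a transpose by the swapped permutation; `swapaxis_axes` and `transpose_shape` are hypothetical helpers working on shapes only.

```python
def swapaxis_axes(ndims, dim1, dim2):
    """new_axis = range(ndims) with positions dim1 and dim2 swapped."""
    new_axis = list(range(ndims))
    new_axis[dim1], new_axis[dim2] = new_axis[dim2], new_axis[dim1]
    return tuple(new_axis)


def transpose_shape(shape, axes):
    """Shape of transpose(X, axes): output axis j takes input axis axes[j]."""
    return tuple(shape[a] for a in axes)
```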
- class mrt.tfm_ops.PlusScalar
- op_name = '_plus_scalar'
Transformer Operator Name
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Transform into broadcast_add.
- class mrt.tfm_ops.ZerosLike
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Make constant zeros with a fixed shape.
- class mrt.tfm_ops.OnesLike
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Make constant ones with a fixed shape.
- class mrt.tfm_ops.GreaterScalar
- validate(op, **kwargs)
Customized validate pass Introduction.
Only integer scalars are supported.
- class mrt.tfm_ops.Where
- op_name = 'where'
Transformer Operator Name
- class mrt.tfm_ops.Squeeze
- op_name = 'squeeze'
Transformer Operator Name
- class mrt.tfm_ops.BatchDot
- quantize(op, **kwargs)
Customized quantize pass Introduction.
The inputs are quantized into the same precision level.
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Use matrix decomposition to avoid overflow:
\[Y = A \cdot B = A_1 \cdot B_1 + A_2 \cdot B_2 + \cdots\]
where
\[A_i\text{.shape} = (batch, M, step), \quad B_i\text{.shape} = (batch, step, N)\]
- class mrt.tfm_ops.BroadcastLike
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
The original operator:
\[op = broadcast\_like(X, W, lhs\_axes, rhs\_axes)\]
can be transformed under three conditions.
Case 1. Null Attributes
\[mul = broadcast\_mul(W, 0)\]\[add = broadcast\_add(mul, 1)\]\[op = broadcast\_mul(X, add)\]
Case 2. Batch Axis not in Attributes
Calculate the attributes for tile as follows:
cnts = {v: wshp[rhs_axes[i]] for i, v in enumerate(lhs_axes)}
reps = tuple([cnts[v] if v in lhs_axes else 1 for v in range(xndims)])
where wshp is the shape of the input weight and xndims is the number of dimensions of X.
\[op = tile(X, reps=reps)\]
Case 3. Other Cases
In this case, only the injection of the batch axis from lhs_axes to rhs_axes is supported.
After the transformation of the weight (see the source code for reference), the result is calculated as in Case 1.
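The Case 2 repeat counts can be computed directly from the shapes. This sketch (hypothetical helpers, shapes only) assumes every axis listed in lhs_axes has size 1 in X, as broadcasting requires; tiling by `reps` then produces the same shape as broadcast_like.

```python
def tile_reps(xndims, wshp, lhs_axes, rhs_axes):
    """Case 2: per-axis repeat counts for tile(X, reps)."""
    cnts = {v: wshp[rhs_axes[i]] for i, v in enumerate(lhs_axes)}
    return tuple(cnts[v] if v in lhs_axes else 1 for v in range(xndims))


def tiled_shape(xshp, reps):
    """Shape of tile(X, reps) when len(reps) == len(xshp)."""
    return tuple(s * r for s, r in zip(xshp, reps))
```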
- class mrt.tfm_ops.ReshapeLike
- rewrite(op, **kwargs)
Customized rewrite pass Introduction.
Equivalently transformed into reshape. Since dynamic shape fusion is supported, only a single batch axis is supported, rather than multiple ones.
- mrt.tfm_ops._quantize_scale(op, **kwargs)
Quantization function for inputs of the form:
\[Y = f(node1, node2, node3, ...)\]
where node1, node2, node3, … must be quantized to the same scale.
The inferred precision depends on the input with the maximum tight quantized precision.
- mrt.tfm_ops._quantize_xwb(op, **kwargs)
Quantization function for inputs of the form:
\[Y = X*W + B\]
The input and weight are quantized into the same precision level. The bias is quantized with respect to the product of the input and weight scales.
The inferred precision equals the sum of the quantized input precision, the quantized weight precision and the product precision.
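A scalar sketch of the X*W + B scheme: with X quantized at scale xs and W at scale ws, every product xq*wq carries scale xs*ws, so the bias must be quantized at that same scale before it can be added as an integer. Round-to-nearest quantization and the helper names are assumptions of this sketch.

```python
def quantize(v, scale):
    """Float -> int at the given scale (round to nearest)."""
    return round(v * scale)


def xwb_quantized(x, w, b, xs, ws):
    """Integer-only Y = X*W + B for one output; the bias uses the product scale xs * ws."""
    xq = [quantize(v, xs) for v in x]
    wq = [quantize(v, ws) for v in w]
    bq = quantize(b, xs * ws)  # bias shares the scale of the X*W products
    yq = sum(a * c for a, c in zip(xq, wq)) + bq
    return yq, xs * ws         # integer result and its scale
```

Dividing the integer result by its scale recovers the float value up to rounding error.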
- mrt.tfm_ops._quantize_table(op, **kwargs)
Quantization function for inputs of the form:
\[Y = f(X) = g(exp(X))\]
Step 1. Requant Input
\[Xq, xprec, xs = requant\_operator(X, iprec, oscale)\]
where ‘Xq’ is the quantized symbol, and ‘xprec’ and ‘xs’ stand for the quantized precision and scale. For reference of ‘requant_operator’, see mrt.tfm_utils.requant_operator.
\[offset = BroadcastAdd(Xq, alpha)\]
where ‘alpha’ is the threshold of ‘Xq’.
Step 2. Create Lookup Table
\[r = range(-alpha, alpha+1) / xs\]
where ‘r’ stands for the table range of the unscaled input.
\[out = f(r)\]\[oscale = scale(||out||_2, xprec)\]
For reference of ‘scale’, see mrt.tfm_utils.scale.
\[in\_dim = 2*alpha+1\]\[out\_q = round(out * oscale)\]
At this point, the float-simulating-int process has been accomplished.
\[W = reshape(round(out\_q * oscale), in\_dim)\]
At this point, the lookup table has been created.
Step 3. Get Output
The cvm customized operator ‘cvm_lut’ is adopted.
\[op = cvm\_lut(offset, W, in\_dim)\]
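A sketch of the table construction for f = sigmoid. It assumes round-to-nearest quantization with a fixed `oscale` (rather than mrt.tfm_utils.scale) and clamps out-of-range inputs to [-alpha, alpha], which is our own choice; indices into the table are Xq + alpha, covering [0, 2*alpha]. Helper names are hypothetical.

```python
import math


def build_table(f, alpha, xs, oscale):
    """Table over r = range(-alpha, alpha + 1) / xs, scaled to integers by oscale."""
    return [round(f(k / xs) * oscale) for k in range(-alpha, alpha + 1)]


def lut_lookup(xq, table, alpha):
    """Index with Xq + alpha, which maps signed inputs into [0, 2 * alpha]."""
    # Clamping to [-alpha, alpha] is an assumption of this sketch.
    return [table[min(max(v, -alpha), alpha) + alpha] for v in xq]


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))
```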
- mrt.tfm_ops.reverse_sequence(op)
Reversing the symbol sequence may lead to a different (incorrect) result, because the graph transformation is not equivalent in general.
Example
A -> B -> C |-> D -> E
after reversing the sequence becomes
B -> A -> C |-> D -> E
which is invalid.
- Notice:
The fuse_transpose pass has the same hidden problem.
- mrt.tfm_ops.reverse_transpose(op)
For a symbol with a single Transpose input, reverse the sequence if the two ops are swappable.
X -> Transpose -> op
after reversing the sequence becomes
X -> op -> Transpose
- Notice:
The axes of the Transpose remain the same before and after the swap.
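Elementwise ops such as Relu commute with Transpose, which is why the swap can keep the Transpose axes unchanged. A two-function check on nested lists (hypothetical helpers standing in for the mxnet ops):

```python
def transpose2d(m):
    """2-D transpose of a nested list."""
    return [list(col) for col in zip(*m)]


def relu2d(m):
    """Elementwise relu on a nested list."""
    return [[max(v, 0) for v in row] for row in m]
```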
mrt.cvm_op
Customized Op-level realization of MxNet Forward Computing Framework.
MxNet customized operator property class.
Only crucial parts of the customized forward operator implementation are elaborated.
- class mrt.cvm_op.Clip(precision, **kwargs)
- forward(is_train, req, in_data, out_data, aux)
MxNet customized operator forward implementation.
Clip the input within [-2^prec, 2^prec].
\[rnd = round(X)\]
where X is the input tensor.
\[out = clip(rnd, -2^{prec}, 2^{prec})\]
where prec is an operator attribute representing integer bits.
- class mrt.cvm_op.LeftShift(precision, shift_bit, **kwargs)
- forward(is_train, req, in_data, out_data, aux)
MxNet customized operator forward implementation.
Left shift the data by sb bits and clip within [-2^prec, 2^prec].
\[rnd = round(X)\]
where X is the input tensor.
\[val = rnd * 2^{sb}\]\[out = clip(val, -2^{prec}, 2^{prec})\]
where prec is an operator attribute representing integer bits, and sb represents the left shift bits.
- class mrt.cvm_op.RightShift(precision, shift_bit, **kwargs)
- forward(is_train, req, in_data, out_data, aux)
MxNet customized operator forward implementation.
Right shift the data by sb bits and clip within [-2^prec, 2^prec].
\[rnd = round(X)\]
where X is the input tensor.
\[val0 = floor(rnd / 2^{sb-1})\]\[val1 = floor((val0+1) / 2)\]\[out = clip(val1, -2^{prec}, 2^{prec})\]
where prec is an operator attribute representing integer bits, and sb represents the right shift bits.
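The three forward computations can be sketched in plain Python on lists. Two assumptions are labeled in the code: round(X) is taken to round halves away from zero, and the clip bound follows the [-2^prec, 2^prec] range stated above. The right shift adds one before the final halving, so it rounds halves toward positive infinity; it is written for sb >= 1. Function names are hypothetical.

```python
import math


def round_half_away(x):
    """Assumed rounding mode: nearest integer, halves away from zero."""
    return int(math.copysign(math.floor(abs(x) + 0.5), x))


def clip(v, prec):
    """Clip into [-2^prec, 2^prec], the bound stated in this document."""
    bound = 2 ** prec
    return max(-bound, min(bound, v))


def cvm_clip(xs, prec):
    return [clip(round_half_away(x), prec) for x in xs]


def cvm_left_shift(xs, prec, sb):
    return [clip(round_half_away(x) * 2 ** sb, prec) for x in xs]


def cvm_right_shift(xs, prec, sb):
    """Rounding right shift by sb bits, valid for sb >= 1."""
    out = []
    for x in xs:
        rnd = round_half_away(x)
        val0 = rnd // 2 ** (sb - 1)  # floor(rnd / 2^(sb-1)); // floors toward -inf
        val1 = (val0 + 1) // 2       # floor((val0 + 1) / 2)
        out.append(clip(val1, prec))
    return out
```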
- class mrt.cvm_op.LUT(in_dim, **kwargs)
- forward(is_train, req, in_data, out_data, aux)
MxNet customized operator forward implementation.
Embed X with respect to T with a vocabulary size of indim. The dimension of the embedding vectors is 1,
where X is the input tensor and T is the weight tensor.
\[val = Embedding(X, T, indim, 1)\]\[out = squeeze(val, axis=-1)\]
where indim is an operator attribute representing the vocabulary size of input indices.
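Since the embedding dimension is 1, the LUT forward reduces to a plain table lookup. A sketch on nested lists, where `table` has shape (indim, 1) and indices are cast to int (our assumption about float-valued inputs); `cvm_lut` here is a hypothetical stand-in, not the registered operator.

```python
def cvm_lut(x, table):
    """Embedding(X, T, indim, 1) followed by squeeze(axis=-1): out[...] = T[X[...]][0]."""
    if isinstance(x, list):
        return [cvm_lut(v, table) for v in x]  # keep the input's nesting (shape)
    return table[int(x)][0]
```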
- class mrt.cvm_op.ClipProp(precision=8, shift_bit=0)
MxNet cvm_clip operator property class.
- class mrt.cvm_op.LeftShiftProp(precision=8, shift_bit=0)
MxNet cvm_left_shift operator property class.
- class mrt.cvm_op.RightShiftProp(precision=8, shift_bit=0)
MxNet cvm_right_shift operator property class.
- class mrt.cvm_op.LUTProp(in_dim)
MxNet cvm_lut operator property class.