MRT Operator API

mrt.tfm_base

Customized Symbolic Pass Interfaces. Base passes with default operation settings. Collection of transformer management functions.

class mrt.tfm_base.Transformer

Base transformer object

All subclasses inherited from this object should be registered manually using the helper function register_transformer. Every class function should then be carefully considered: either override it, or, if there is no point in redefining a duplicate function, use the helper function register_pass to annotate it as using the function defined in the base class (that is, this object).

Subclasses should only implement functions defined in the base object, and we advise naming any helper function with an underscore prefix.

Please refer to the file tfm_ops.py for more examples of operator transformers.

calculate_ops(op, **kwargs)

Calculate the amount of computations for operator.

Returns the output size by default.

compile(op, **kwargs)

Compile mxnet symbol into nnvm symbol.

Throw exception by default.

fuse_transpose(op, **kwargs)

Equivalent graph transposition.

When at least one of two adjacent ops is a Transpose, the other op may be either swappable or fusable with the Transpose.

Do nothing by default.

op_name = 'none'

Transformer Operator Name

Each transformer is associated with an operator defined in mxnet, and this variable indicates the type name of the mxnet symbol.

Note that the base transformer should not be instantiated, since it is just an abstract aggregation of graph passes; it is named none by default.

prepare_for_compile(op, **kwargs)

Equivalent graph transition that may be needed before compilation, such as the dynamic shape fixation for MulScalar, DivScalar, ZerosLike and OnesLike, which is only needed in quantization. Do nothing by default.

quantize(op, **kwargs)

Main procedure for quantization.

Do nothing by default.

rewrite(op, **kwargs)
Operators may need to be rewritten into an equivalent graph that is easier to quantize in later procedures.

Do nothing by default.

validate(op, **kwargs)
All operators should be validated before any other pass, either correcting an invalid format or raising an assertion error to report an unsupported graph.

Do nothing by default.

mrt.tfm_ops

Op-level realization of Model Representation Tool. Implementation of validation, equivalent transformation, quantization, transpose fusion, ops calculation, preparation and compilation.

Only crucial parts of the customized pass implementations are elaborated.

The supported MXNet operators are listed below:

class mrt.tfm_ops.Null
quantize(op, **kwargs)

Customized quantize pass Introduction.

Transform the input data.

class mrt.tfm_ops.Transpose
fuse_transpose(op, **kwargs)

Customized fuse_transpose pass Introduction.

For a sequence of consecutive transposes such as:

\[Y = \text{Transpose}(X, axes_1)\]
\[Z = \text{Transpose}(Y, axes_2)\]

exert an equivalent transformation on the two adjacent transpose ops:

\[\begin{aligned} Z &= \text{Transpose}(Y, axes_2) \\ &= \text{Transpose}(\text{Transpose}(X, axes_1), axes_2) \\ &= \text{Transpose}(X, axes_3), \quad \text{where } axes_3(j) = axes_1(axes_2(j)) \end{aligned}\]

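
The composition rule can be checked numerically. The sketch below uses NumPy for illustration (MRT itself operates on mxnet symbols); the helper name is hypothetical. It builds axes_3 from axes_1 and axes_2 and compares the fused transpose against the original pair. When axes_3 is the identity permutation, the two transposes cancel out entirely.

```python
import numpy as np


def compose_transpose_axes(axes_1, axes_2):
    """axes_3 with Transpose(Transpose(X, axes_1), axes_2) == Transpose(X, axes_3)."""
    return tuple(axes_1[a] for a in axes_2)   # axes_3(j) = axes_1(axes_2(j))


X = np.arange(24).reshape(2, 3, 4)
for axes_1, axes_2 in [((0, 2, 1), (2, 1, 0)), ((1, 2, 0), (2, 0, 1))]:
    axes_3 = compose_transpose_axes(axes_1, axes_2)
    Z = np.transpose(np.transpose(X, axes_1), axes_2)
    assert np.array_equal(Z, np.transpose(X, axes_3))

print(compose_transpose_axes((1, 2, 0), (2, 0, 1)))  # (0, 1, 2): the pair cancels
```
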
class mrt.tfm_ops.Relu
fuse_transpose(op, **kwargs)

Customized fuse_transpose pass Introduction.

See mrt.tfm_ops.reverse_transpose for reference.

class mrt.tfm_ops.LeakyReLU
fuse_transpose(op, **kwargs)

Customized fuse_transpose pass Introduction.

See mrt.tfm_ops.reverse_transpose for reference.

rewrite(op, **kwargs)

Customized rewrite pass Introduction.

LeakyReLU can be equivalently transformed into operators supported by cvm:

\[LeakyReLU(X) = relu(X) - slope*relu(-X)\]

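
A quick NumPy check of the identity above (NumPy is only used for illustration): for negative inputs relu(X) vanishes and -slope*relu(-X) equals slope*X, while for non-negative inputs the second term vanishes.

```python
import numpy as np


def leaky_relu(x, slope):
    return np.where(x >= 0, x, slope * x)


def leaky_relu_rewritten(x, slope):
    relu = lambda t: np.maximum(t, 0)
    return relu(x) - slope * relu(-x)   # relu(X) - slope*relu(-X)


x = np.linspace(-3.0, 3.0, 13)
assert np.allclose(leaky_relu(x, 0.1), leaky_relu_rewritten(x, 0.1))
```
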
validate(op, **kwargs)

Customized validate pass Introduction.

Only the ‘leaky’ activation type is supported.

class mrt.tfm_ops.MulScalar
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Transform into broadcast_mul.

class mrt.tfm_ops.DivScalar
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Transform into broadcast_mul.

class mrt.tfm_ops.Activation
rewrite(op, **kwargs)

Equivalent transformation of the operator. It only applies when the attribute act_type equals relu or sigmoid, in which case the op can be directly transformed into the corresponding operator.

validate(op, **kwargs)

Customized validate pass Introduction.

The activation function only supports relu.

class mrt.tfm_ops.Convolution
quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_xwb for reference

rewrite(op, **kwargs)

Customized rewrite pass Introduction.

An input node with layout NCW is equivalently rewritten into layout NCHW. The parameters of the attached dimension are set to their defaults.
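
The attribute side of this rewrite can be sketched as follows. The function names and the neutral defaults for the new height dimension (kernel 1, stride 1, dilate 1, pad 0) are assumptions for illustration, and attributes are passed as Python tuples rather than mxnet attribute strings.

```python
def ncw_attrs_to_nchw(attrs):
    """Hypothetical sketch: lift 1-D Convolution attributes (NCW) to 2-D (NCHW)
    by prepending a height entry with a neutral default value."""
    height_defaults = {"kernel": 1, "stride": 1, "dilate": 1, "pad": 0}
    new_attrs = dict(attrs)
    for key, default in height_defaults.items():
        if key in attrs:
            new_attrs[key] = (default,) + tuple(attrs[key])
    new_attrs["layout"] = "NCHW"
    return new_attrs


def insert_height_axis(shape):
    """(N, C, W) -> (N, C, 1, W); weights follow the same rule: (O, C/g, K) -> (O, C/g, 1, K)."""
    return tuple(shape[:2]) + (1,) + tuple(shape[2:])


attrs = {"kernel": (3,), "stride": (2,), "pad": (1,), "num_filter": 8, "layout": "NCW"}
print(ncw_attrs_to_nchw(attrs))
# {'kernel': (1, 3), 'stride': (1, 2), 'pad': (0, 1), 'num_filter': 8, 'layout': 'NCHW'}
```
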

class mrt.tfm_ops.Pad
compile(op, **kwargs)

Customized compile pass Introduction.

Only the constant padding mode is supported.

class mrt.tfm_ops.ExpandDims
compile(op, **kwargs)

op_name = 'expand_dims'

class mrt.tfm_ops.Embedding
compile(op, **kwargs)

op_name = 'Embedding'

quantize(op, **kwargs)

Customized quantize pass Introduction.

The input array to the embedding operator should be scaled to 1.

class mrt.tfm_ops.Repeat
compile(op, **kwargs)

op_name = 'repeat'

class mrt.tfm_ops.BoxNms
compile(op, **kwargs)

op_name = '_contrib_box_nms'

class mrt.tfm_ops.SliceLike
compile(op, **kwargs)

op_name = 'slice_like'

quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_scale for reference.

class mrt.tfm_ops.SliceAxis
op_name = 'slice_axis'

rewrite(op, **kwargs)

Customized rewrite pass Introduction.

‘SliceAxis’ can be equivalently transformed into ‘Slice’ which is supported by cvm.

validate(op, **kwargs)

class mrt.tfm_ops.SliceChannel
op_name = 'SliceChannel'

class mrt.tfm_ops.UpSampling
compile(op, **kwargs)

op_name = 'UpSampling'

validate(op, **kwargs)

class mrt.tfm_ops.FullyConnected
quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_xwb for reference

reduce(op, **kwargs)

Dimension reduction function considering both flatten cases.

Denote the input as X and the transformed operator as Y. If flatten is true, only one reduction of the high-dimensional input to 2 dimensions is needed:

\[RX = reshape(X)\]
\[Y = FullyConnected(RX)\]

If flatten is false, the input is first reduced to 2 dimensions, and after the FullyConnected op the output is reshaped back to the correct output shape:

\[RX = reshape(X)\]
\[out = FullyConnected(RX)\]
\[Y = reshape(out)\]

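
Both flatten cases can be sketched with NumPy, assuming MXNet's FullyConnected semantics: with flatten true all trailing axes are merged into one feature axis, and with flatten false the op acts on the last axis only. The helper names are hypothetical.

```python
import numpy as np


def fully_connected_2d(x2d, w, b):
    # 2-D FullyConnected: (batch, K) @ (num_hidden, K)^T + (num_hidden,)
    return x2d @ w.T + b


def fc_flatten_true(x, w, b):
    rx = x.reshape(x.shape[0], -1)                    # RX = reshape(X)
    return fully_connected_2d(rx, w, b)               # Y = FullyConnected(RX)


def fc_flatten_false(x, w, b):
    rx = x.reshape(-1, x.shape[-1])                   # RX = reshape(X)
    out = fully_connected_2d(rx, w, b)                # out = FullyConnected(RX)
    return out.reshape(x.shape[:-1] + (w.shape[0],))  # Y = reshape(out)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4))
b = rng.normal(size=5)
w_last = rng.normal(size=(5, 4))     # flatten=False: acts on the last axis
w_flat = rng.normal(size=(5, 12))    # flatten=True: acts on 3*4 features
print(fc_flatten_false(x, w_last, b).shape)   # (2, 3, 5)
print(fc_flatten_true(x, w_flat, b).shape)    # (2, 5)
```
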
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Using matrix decomposition to avoid overflow.

\[Y = B + X*W^T = B + X1*W1^T + X2*W2^T + ...\]
\[Wi.shape = (numHidden, step), W = [W1, W2, ...]\]
\[Xi.shape = (batchSize, step), X = [X1, X2, ...]\]
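
The decomposition is exact, which the NumPy sketch below checks with integer data; the last chunk may be narrower than step. The helper name is hypothetical, and the sketch only demonstrates the algebraic identity: how MRT requantizes each partial product is not modeled.

```python
import numpy as np


def fc_decomposed(x, w, b, step):
    """Y = B + sum_i X_i W_i^T with X_i: (batch_size, step), W_i: (num_hidden, step)."""
    y = np.broadcast_to(b, (x.shape[0], w.shape[0])).copy()
    for start in range(0, x.shape[1], step):
        xi = x[:, start:start + step]
        wi = w[:, start:start + step]
        y += xi @ wi.T
    return y


rng = np.random.default_rng(0)
x = rng.integers(-127, 128, size=(3, 10))
w = rng.integers(-127, 128, size=(4, 10))
b = rng.integers(-1000, 1000, size=4)
assert np.array_equal(fc_decomposed(x, w, b, step=4), x @ w.T + b)
```
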
class mrt.tfm_ops.Sigmoid
quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_table for reference

class mrt.tfm_ops.Exp
quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_table for reference

class mrt.tfm_ops.Softmax
quantize(op, **kwargs)

Customized quantize pass Introduction.

Step 1. Requant Input

\[Xq, xprec, xs = requant\_operator(X, iprec, oscale)\]

Step 2. Calculate Norm Value

First, calculate the bias with respect to unscaled input (denoted as ‘alpha’) as:

\[alpha = int(lambd * xs)\]

where ‘lambd’ stands for a hyperparameter configured by user.

Then, calculate the offset value along the axis, and clip it:

\[max\_axis = max(Xq, axis)\]
\[offset = broadcast\_sub(max\_axis, alpha)\]
\[offset\_c = clip(offset, xprec)\]

Next, calculate norm and clip it:

\[norm = relu(Xq - offset\_c)\]
\[norm = clip(norm, xprec)\]

Step 3. Create Lookup Table

\[dim = alpha + 1\]
\[data = range(0, dim)\]
\[table = exp(data / xs)\]
\[table\_prec = get\_bit(exp(||table||_2))\]

For reference on ‘get_bit’, see mrt.tfm_utils.get_bit.

\[table\_c = clip(table, min=0, max=get\_range(table\_prec))\]

For reference on ‘get_range’, see mrt.tfm_utils.get_range.

\[weight = reshape(round(table\_c), (dim, 1))\]

Step 4. Get Lookup value

The cvm customized operator ‘cvm_lut’ has been adopted.

\[lut = cvm\_lut(norm, weight, dim)\]

Step 5. Get Output

\[sum\_lut = sum(lut, axis)\]
\[oprec = min(15, 31 - table\_prec)\]
\[oscale = get\_range(oprec)\]
\[prob = lut * oscale\]
\[half\_lut = realize(sum\_lut, 1, 31)\]

For reference on ‘realize’, see mrt.tfm_utils.realize.

\[prob\_b = prob + half\_lut\]
\[prob = prob\_b / sum\_lut\]

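
The sketch below is a simplified, integer-only rendering of Steps 2 to 5 in NumPy, intended only to show why the lookup table reproduces softmax: subtracting the row maximum and adding alpha keeps every index within [0, alpha], and the common factor exp(alpha / xs) cancels in the final division. Assumptions: cvm_lut is modeled as plain array indexing, the output range is taken as 2**(oprec - 1) - 1, table entries are rounded but not clipped, and the xprec clipping of Step 2 is omitted. None of these helpers are MRT functions.

```python
import numpy as np


def lut_softmax_sketch(xq, xs, lambd=10.0, oprec=15):
    """Integer softmax over the last axis via a lookup table (simplified sketch).

    xq holds integers whose real value is roughly xq / xs.
    """
    xq = np.asarray(xq, dtype=np.int64)
    alpha = int(lambd * xs)                          # Step 2: bias
    max_axis = xq.max(axis=-1, keepdims=True)
    offset = max_axis - alpha
    norm = np.maximum(xq - offset, 0)                # relu: indices in [0, alpha]
    dim = alpha + 1                                  # Step 3: lookup table
    table = np.round(np.exp(np.arange(dim) / xs)).astype(np.int64)
    lut = table[norm]                                # Step 4: cvm_lut as indexing
    sum_lut = lut.sum(axis=-1, keepdims=True)        # Step 5
    oscale = 2 ** (oprec - 1) - 1                    # assumed output range
    half_lut = sum_lut // 2                          # rounds the integer division
    prob = (lut * oscale + half_lut) // sum_lut
    return prob, oscale


x = np.array([[1.0, 2.0, 3.0], [0.5, -1.0, 0.25]])
xs = 16
prob, oscale = lut_softmax_sketch(np.round(x * xs), xs)
ref = np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)
assert np.allclose(prob / oscale, ref, atol=2e-3)
```
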
class mrt.tfm_ops.Pooling
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Case 1. ‘pool_type’ is ‘avg’ and ‘global_pool’ is True

\[scale\_sym = 1 / (xshp[2] * xshp[3])\]

where ‘xshp’ is the inferred shape of the input ‘X’.

\[op\_s = sum(X, axis=(2, 3))\]
\[op = op\_s * scale\_sym\]

Case 2. ‘pool_type’ is ‘avg’ and ‘global_pool’ is False

The ‘Pooling’ operator can be converted into a depthwise ‘Convolution’. First, set up the attributes:

conv_attr = {
    'no_bias': 'True',
    'dilate': '(1, 1)',
    'kernel': kernel,
    'stride': stride,
    'pad': pad,
    'layout': 'NCHW',
    'num_filter': xshp[1],
    'num_group': xshp[1],
}

where ‘kernel’ is the pooling kernel size, ‘stride’ is the pooling stride and ‘pad’ is the pooling padding. Then, build a constant averaging weight and apply the convolution:

\[W = full(shape=wshp, val=1/product(kernel))\]
\[op = Convolution(X, W, conv\_attr)\]

where ‘wshp’ is the weight shape (xshp[1], 1) + kernel, i.e. one filter per channel, since num_group equals num_filter.

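
The equivalence in Case 2 can be checked with two naive NumPy reference implementations (hypothetical helpers, square kernel, no padding): average pooling, and a depthwise convolution whose weight of shape (C, 1, k, k) is filled with 1/(k*k). The last assertion covers Case 1, where the scaled sum over axes 2 and 3 equals the global mean.

```python
import numpy as np


def avg_pool2d(x, k, s):
    """Reference average pooling: NCHW, square kernel k, stride s, no padding."""
    n, c, h, w = x.shape
    oh, ow = (h - k) // s + 1, (w - k) // s + 1
    out = np.empty((n, c, oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[:, :, i, j] = x[:, :, i*s:i*s+k, j*s:j*s+k].mean(axis=(2, 3))
    return out


def depthwise_conv2d(x, weight, s):
    """Convolution with num_group == num_filter == C, no bias, no padding."""
    n, c, h, w = x.shape
    k = weight.shape[-1]
    oh, ow = (h - k) // s + 1, (w - k) // s + 1
    out = np.empty((n, c, oh, ow))
    for i in range(oh):
        for j in range(ow):
            patch = x[:, :, i*s:i*s+k, j*s:j*s+k]              # (n, c, k, k)
            out[:, :, i, j] = (patch * weight[:, 0]).sum(axis=(2, 3))
    return out


x = np.random.default_rng(0).normal(size=(2, 3, 6, 6))
k, s = 2, 2
W = np.full((x.shape[1], 1, k, k), 1.0 / (k * k))   # wshp = (C, 1, k, k)
assert np.allclose(avg_pool2d(x, k, s), depthwise_conv2d(x, W, s))

# Case 1: the scaled sum over the spatial axes equals the global average.
scale_sym = 1.0 / (x.shape[2] * x.shape[3])
assert np.allclose(x.sum(axis=(2, 3)) * scale_sym, x.mean(axis=(2, 3)))
```
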
validate(op, **kwargs)

Customized validate pass Introduction.

Only the ‘NCHW’ layout is supported.

Only the ‘max’ and ‘avg’ pool types are supported. If ‘pool_type’ is ‘avg’ and ‘pooling_convention’ is ‘full’, then ‘global_pool’ must be True. If ‘pool_type’ is ‘avg’ and ‘pooling_convention’ is not ‘full’, then ‘pooling_convention’ must be ‘valid’ and ‘global_pool’ must be True.

‘count_include_pad’ must be True.

class mrt.tfm_ops.BroadcastMul
prepare_for_compile(op, **kwargs)

Customized prepare_for_compile pass Introduction.

If either of the inputs equals zero, the op can be merged into zero.

quantize(op, **kwargs)

Customized quantize pass Introduction.

\[Xq, xprec, xs = requant(X, oprec)\]
\[Bq, bprec, bs = requant(B, oprec)\]

where ‘oprec’ stands for the default quantization precision for ‘BroadcastMul’.

See mrt.tfm_utils.requant for reference.

\[op = Xq * Bq\]

The inferred precision equals ‘xprec’ plus ‘bprec’.

class mrt.tfm_ops.BroadcastAdd
quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_scale for reference.

class mrt.tfm_ops.BroadcastSub
quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_scale for reference.

class mrt.tfm_ops.BroadcastTo
op_name = 'broadcast_to'

class mrt.tfm_ops.BroadcastGreater
op_name = 'broadcast_greater'

class mrt.tfm_ops.Concat
fuse_transpose(op, **kwargs)

Customized fuse_transpose pass Introduction.

Suppose all inputs are ‘Transpose’ ops about the same ‘axis’:

    cA         cB   ..    cC
    |          |          |
    |          |          |
Transpose  Transpose  Transpose
  (axis)    (axis)  ..  (axis)
       \       |        /
        \      |       /
         \     |      /
          \    |     /
             Concat

then, the graph can be transformed into:

cA    cB .. cC
 \    |    /
  \   |   /
   \  |  /
    Concat
      |
  Transpose
    (axis)

where ‘Transpose’ here is also about ‘axis’.
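
A NumPy check of this rewrite, under the assumption that ‘axis’ is a full permutation: the concatenation dimension also has to be mapped through the permutation, from dim on the transposed inputs to axes[dim] on the original inputs. The helper name is hypothetical.

```python
import numpy as np


def fused_concat_transpose(inputs, axes, dim):
    """Rewrite Concat([Transpose(c, axes) for c in inputs], dim)
    as Transpose(Concat(inputs, axes[dim]), axes)."""
    return np.transpose(np.concatenate(inputs, axis=axes[dim]), axes)


rng = np.random.default_rng(0)
axes, dim = (0, 2, 1), 2
inputs = [rng.normal(size=(2, n, 4)) for n in (3, 5, 1)]   # cA, cB, cC
original = np.concatenate([np.transpose(c, axes) for c in inputs], axis=dim)
assert np.array_equal(original, fused_concat_transpose(inputs, axes, dim))
```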

quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_scale for reference.

class mrt.tfm_ops.Sum
fuse_transpose(op, **kwargs)

quantize(op, **kwargs)

Customized quantize pass Introduction.

\[Xq, xprec, xs = requant(X, oprec)\]

where ‘oprec’ stands for the default quantization precision for ‘Sum’.

See mrt.tfm_utils.requant for reference.

\[k = int(product(ishp) / product(oshp))\]
\[kprec = get\_bit\_cnt(k)\]

where ‘ishp’ and ‘oshp’ stand for the input shape and output shape respectively.

See mrt.tfm_utils.get_bit_cnt for reference.

The inferred precision equals ‘xprec’ plus ‘kprec’.
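
The precision bound follows because each output element sums k inputs, so its magnitude grows by at most a factor of k, which needs kprec extra bits. The sketch below assumes get_bit_cnt returns the bit length of k (a hypothetical stand-in, not the MRT helper) and counts precision as magnitude bits.

```python
import numpy as np


def get_bit_cnt_sketch(k):
    """Hypothetical stand-in for mrt.tfm_utils.get_bit_cnt: bits needed to hold k."""
    return int(k).bit_length()


def sum_infer_prec(ishp, oshp, xprec):
    k = int(np.prod(ishp) // np.prod(oshp))   # inputs accumulated per output element
    return xprec + get_bit_cnt_sketch(k)


ishp, oshp, xprec = (2, 3, 5, 7), (2, 3), 8
prec = sum_infer_prec(ishp, oshp, xprec)      # k = 35 -> 6 bits -> 14
worst = np.full(ishp, 2 ** xprec - 1).sum(axis=(2, 3))   # every input at max magnitude
assert worst.max() < 2 ** prec
```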

class mrt.tfm_ops.BatchNorm
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Case 1. Convolution Input

When the graph looks like this:

X    W    B
 \   |   /
Convolution gamma beta data_mean data_var
     |        |     |     |         |
     |        |     |     |         |
     --------------------------------
                    |
                BatchNorm

\[sc = gamma / sqrt(data\_var + eps)\]
\[weight = W * sc.reshape(C, 1, 1, 1)\]
\[bias = beta + sc * (B - data\_mean)\]
\[op = Convolution(X, weight, bias, conv\_attr)\]

where ‘C’ is the number of output channels and ‘conv_attr’ is the attribute set of the original ‘Convolution’.

Case 2. Other Cases

X  gamma beta data_mean data_var
|    |     |     |         |
|    |     |     |         |
----------------------------
           |
       BatchNorm
rshp = [s if i == axis else 1 for i, s in enumerate(ishp)]

where ‘axis’ is an attribute of the operator, ‘ishp’ is the input shape of ‘X’ and ‘sc’ is computed as in Case 1.

\[weight = reshape(sc, rshp)\]
\[bias = reshape(beta - sc * data\_mean, rshp)\]
\[op = X * weight + bias\]

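
The Case 1 folding can be checked numerically. The sketch below uses a 1x1 convolution written with np.einsum for brevity (the algebra is the same for any kernel size, since the convolution is linear in each output channel's filter); all names are illustrative. The existing convolution bias B is scaled by sc as well, because BatchNorm is applied to the whole convolution output.

```python
import numpy as np


def conv1x1(x, w, b):
    """1x1 Convolution (NCHW): out[n,o,h,w] = sum_c w[o,c,0,0] * x[n,c,h,w] + b[o]."""
    return np.einsum("oc,nchw->nohw", w[:, :, 0, 0], x) + b[None, :, None, None]


def batchnorm(y, gamma, beta, mean, var, eps):
    r = lambda v: v[None, :, None, None]            # per-channel broadcast over NCHW
    return r(gamma) * (y - r(mean)) / np.sqrt(r(var) + eps) + r(beta)


rng = np.random.default_rng(0)
c_in, c_out, eps = 3, 4, 1e-5
x = rng.normal(size=(2, c_in, 5, 5))
W, B = rng.normal(size=(c_out, c_in, 1, 1)), rng.normal(size=c_out)
gamma, beta, mean = rng.normal(size=(3, c_out))
var = rng.uniform(0.5, 2.0, size=c_out)

sc = gamma / np.sqrt(var + eps)
weight = W * sc.reshape(c_out, 1, 1, 1)
bias = beta + sc * (B - mean)                       # B is scaled by sc as well
folded = conv1x1(x, weight, bias)
assert np.allclose(batchnorm(conv1x1(x, W, B), gamma, beta, mean, var, eps), folded)
```
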
class mrt.tfm_ops.Flatten
compile(op, **kwargs)

op_name = 'Flatten'

class mrt.tfm_ops.Floor
compile(op, **kwargs)

op_name = 'floor'

class mrt.tfm_ops.Ceil
compile(op, **kwargs)

op_name = 'ceil'

class mrt.tfm_ops.Round
compile(op, **kwargs)

op_name = 'round'

class mrt.tfm_ops.Fix
compile(op, **kwargs)

op_name = 'fix'

class mrt.tfm_ops.Cast
compile(op, **kwargs)

op_name = 'Cast'

class mrt.tfm_ops.Slice
compile(op, **kwargs)

op_name = 'slice'

class mrt.tfm_ops.Reshape
compile(op, **kwargs)

op_name = 'Reshape'

class mrt.tfm_ops.Custom
validate(op, **kwargs)

Customized validate pass Introduction.

Only the op types ‘cvm_clip’, ‘cvm_left_shift’, ‘cvm_right_shift’ and ‘cvm_lut’ are supported.

class mrt.tfm_ops.Clip
fuse_transpose(op, **kwargs)

Customized fuse_transpose pass Introduction.

See mrt.tfm_ops.reverse_transpose for reference.

quantize(op, **kwargs)

Customized quantize pass Introduction.

\[amin = int(a\_min * iscale)\]
\[amax = int(a\_max * iscale)\]

where ‘a_min’ and ‘a_max’ are attributes of the op, and ‘iscale’ is the scale of the input.

\[op = clip(X, amin, amax)\]

class mrt.tfm_ops.Minimum
compile(op, **kwargs)

op_name = '_minimum'

class mrt.tfm_ops.Maximum
compile(op, **kwargs)

op_name = '_maximum'

class mrt.tfm_ops.Max
op_name = 'max'

class mrt.tfm_ops.Min
op_name = 'min'

class mrt.tfm_ops.Argmax
compile(op, **kwargs)

op_name = 'argmax'

class mrt.tfm_ops.Argmin
compile(op, **kwargs)

op_name = 'argmin'

class mrt.tfm_ops.Abs
op_name = 'abs'

class mrt.tfm_ops.ElemwiseAdd
fuse_transpose(op, **kwargs)

Customized fuse_transpose pass Introduction.

See mrt.tfm_ops._quantize_scale for reference.

quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_scale for reference.

class mrt.tfm_ops.ElemwiseSub
fuse_transpose(op, **kwargs)

Customized fuse_transpose pass Introduction.

See mrt.tfm_ops._quantize_scale for reference.

quantize(op, **kwargs)

Customized quantize pass Introduction.

See mrt.tfm_ops._quantize_scale for reference.

class mrt.tfm_ops.ElemwiseMul
rewrite(op, **kwargs)

Validate that the inferred shapes of lhs and rhs are the same, so that this op can be rewritten into broadcast_mul; the corresponding cvm op is optimized at compile time.

class mrt.tfm_ops.Dropout
fuse_transpose(op, **kwargs)

Customized fuse_transpose pass Introduction.

See mrt.tfm_ops.reverse_transpose for reference.

class mrt.tfm_ops.Arange
op_name = '_arange'

class mrt.tfm_ops.Tile
op_name = 'tile'

class mrt.tfm_ops.Negative
op_name = 'negative'

class mrt.tfm_ops.SwapAxis
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

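\[ndims = len(ishp)\]

where ‘ishp’ is the inferred shape of the input.

\[new\_axis = list(range(ndims))\]
\[new\_axis[dim1], new\_axis[dim2] = new\_axis[dim2], new\_axis[dim1]\]
\[op = transpose(X, axes=new\_axis)\]

where ‘dim1’ and ‘dim2’ are attributes of the op.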
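
The same permutation trick expressed in NumPy, as a sanity check against np.swapaxes (the helper name is hypothetical): start from the identity permutation, swap the two entries, and transpose.

```python
import numpy as np


def swapaxis_as_transpose(x, dim1, dim2):
    new_axis = list(range(x.ndim))                   # identity permutation
    new_axis[dim1], new_axis[dim2] = new_axis[dim2], new_axis[dim1]
    return np.transpose(x, new_axis)


x = np.arange(24).reshape(2, 3, 4)
assert np.array_equal(swapaxis_as_transpose(x, 0, 2), np.swapaxes(x, 0, 2))
```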

class mrt.tfm_ops.PlusScalar
op_name = '_plus_scalar'

rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Transform into broadcast_add.

class mrt.tfm_ops.ZerosLike
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Make constant zeros with fixed shape.

class mrt.tfm_ops.OnesLike
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Make constant ones with fixed shape.

class mrt.tfm_ops.GreaterScalar
validate(op, **kwargs)

Customized validate pass Introduction.

Only integer scalars are supported.

class mrt.tfm_ops.Where
op_name = 'where'

class mrt.tfm_ops.Squeeze
op_name = 'squeeze'

class mrt.tfm_ops.BatchDot
quantize(op, **kwargs)

Customized quantize pass Introduction.

The inputs are quantized into the same precision level.

rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Using matrix decomposition to avoid overflow.

\[Y = A \cdot B = A_1 \cdot B_1 + A_2 \cdot B_2 + ...\]

where

\[A_i\text{.shape} = (batch, M, step)\]
\[B_i\text{.shape} = (batch, step, N)\]

class mrt.tfm_ops.BroadcastLike
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

The original operator:

\[op = broadcast\_like(X, W, lhs\_axes, rhs\_axes)\]

can be transformed in three cases.

Case 1. Null Attributes

\[mul = broadcast\_mul(W, 0)\]
\[add = broadcast\_add(mul, 1)\]
\[op = broadcast\_mul(X, add)\]

Case 2. Batch Axis not in Attributes

Calculate attributes for tile as follows:

cnts = {v: wshp[rhs_axes[i]] for i, v in enumerate(lhs_axes)}
reps = tuple([cnts[v] if v in lhs_axes else 1 for v in range(xndims)])

where ‘wshp’ is the shape of the input weight ‘W’ and ‘xndims’ is the number of dimensions of ‘X’.

\[op = tile(X, reps=reps)\]

Case 3. Other Cases

In this case, only the injection of the batch axis from lhs_axes to rhs_axes is supported.

After transforming the weight (see the source code for reference), the result is calculated as in Case 1.
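
Cases 1 and 2 can be illustrated with NumPy, with np.broadcast_to serving as the reference for broadcast_like. The shapes and axes are made up for the example; in Case 2 the tile repetitions reproduce broadcasting because broadcast_like only expands axes of size 1 in X.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 1, 3))
w = rng.normal(size=(2, 4, 3))

# Case 1: null attributes, X is broadcast against the full shape of W.
case1 = x * (w * 0 + 1)   # broadcast_mul(X, broadcast_add(broadcast_mul(W, 0), 1))
assert np.allclose(case1, np.broadcast_to(x, w.shape))

# Case 2: only the listed axes (batch axis 0 excluded) are expanded with tile.
lhs_axes, rhs_axes = (1,), (1,)
xndims, wshp = x.ndim, w.shape
cnts = {v: wshp[rhs_axes[i]] for i, v in enumerate(lhs_axes)}
reps = tuple([cnts[v] if v in lhs_axes else 1 for v in range(xndims)])
case2 = np.tile(x, reps)
assert np.allclose(case2, np.broadcast_to(x, (2, 4, 3)))
```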

class mrt.tfm_ops.ReshapeLike
rewrite(op, **kwargs)

Customized rewrite pass Introduction.

Equivalently transformed into reshape. Since dynamic shape fusion is supported, only a single batch axis is supported rather than multiple ones.

mrt.tfm_ops._quantize_scale(op, **kwargs)

Quantization function for inputs of the form:

\[Y = f(node1, node2, node3, ...)\]

where node1, node2, node3, … ought to be quantized to the same scale.

The inferred precision is determined by the input with the maximum tight precision after quantization.
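
A minimal sketch of bringing several quantized inputs to one shared scale before combining them, shown for an addition. The choice of the smallest input scale as the shared scale (so that no input grows past its range) is an assumption for illustration, as are both helper names; this is not the MRT implementation.

```python
import numpy as np


def requant_to_scale(q, s_from, s_to):
    """Rescale integers q (real value ~ q / s_from) to the scale s_to."""
    return np.round(q * (s_to / s_from)).astype(np.int64)


def quantize_scale_add(qs, scales):
    """Hypothetical sketch: align every input to one shared scale, then add."""
    target = min(scales)                      # assumption: smallest input scale
    aligned = [requant_to_scale(q, s, target) for q, s in zip(qs, scales)]
    return sum(aligned), target


x1, s1 = np.array([0.5, -1.25]), 64.0
x2, s2 = np.array([0.75, 2.0]), 16.0
q1, q2 = np.round(x1 * s1), np.round(x2 * s2)
yq, ys = quantize_scale_add([q1, q2], [s1, s2])
print(yq / ys)   # [1.25 0.75]
```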

mrt.tfm_ops._quantize_xwb(op, **kwargs)

Quantization function for inputs of the form:

\[Y = X*W + B\]

The input and weight are quantized to the same precision level. The bias is quantized with respect to the product of the input and weight scales.

The inferred precision equals the sum of the quantized input precision, the quantized weight precision and the product precision.
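
A minimal NumPy sketch of the X*W + B scheme, assuming symmetric per-tensor scales that map max |t| onto 2**(prec - 1) - 1 (a hypothetical helper, not mrt.tfm_utils.scale). The bias uses the scale xs * ws so that it can be added directly to the integer product; the product is written X @ W.T to match FullyConnected's weight layout.

```python
import numpy as np


def symmetric_scale(t, prec):
    """Hypothetical helper: map max |t| onto the signed range 2**(prec - 1) - 1."""
    return (2 ** (prec - 1) - 1) / np.abs(t).max()


rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(2, 5))
w = rng.uniform(-1, 1, size=(3, 5))
b = rng.uniform(-1, 1, size=3)

prec = 8
xs, ws = symmetric_scale(x, prec), symmetric_scale(w, prec)   # same precision level
xq = np.round(x * xs).astype(np.int64)
wq = np.round(w * ws).astype(np.int64)
bq = np.round(b * xs * ws).astype(np.int64)   # bias scale is xs * ws
yq = xq @ wq.T + bq                           # integer-only X*W + B
y = yq / (xs * ws)                            # dequantize for comparison
assert np.allclose(y, x @ w.T + b, atol=0.1)
```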

mrt.tfm_ops._quantize_table(op, **kwargs)

Quantization function for inputs of the form:

\[Y = f(X) = g(exp(X))\]

Step 1. Requant Input

\[Xq, xprec, xs = requant\_operator(X, iprec, oscale)\]

where ‘Xq’ is quantized symbol, ‘xprec’ and ‘xs’ stand for quantized precision and scale, for reference of ‘requant_operator’, see mrt.tfm_utils.requant_operator.

\[offset = BroadcastAdd(Xq, alpha)\]

where alpha is the threshold of ‘Xq’

Step 2. Create Lookup Table

\[r = range(-alpha, alpha+1) / xs\]

where ‘r’ stands for the table range of the unscaled input.

\[out = f(r)\]
\[oscale = scale(||out||_2, xprec)\]

for reference of ‘scale’, see mrt.tfm_utils.scale.

\[in\_dim = 2*alpha+1\]
\[out\_q = round(out * oscale)\]

At this point, the float-simulating-int process is complete.

\[W = reshape(out\_q, in\_dim)\]

The lookup table is now created.

Step 3. Get Output

The cvm customized operator ‘cvm_lut’ has been adopted.

\[op = cvm\_lut(offset, W, in\_dim)\]

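
A NumPy sketch of the table scheme for f = sigmoid. It builds the table over r = range(-alpha, alpha + 1) / xs, stores round(f(r) * oscale), and models cvm_lut as array indexing with offset = Xq + alpha. The parameter values are arbitrary, the helper names are hypothetical, and clipping the offset into [0, 2*alpha] is an added assumption for inputs outside the table range.

```python
import numpy as np


def build_table(f, alpha, xs, oscale):
    """Table over the unscaled range r = range(-alpha, alpha + 1) / xs."""
    r = np.arange(-alpha, alpha + 1) / xs
    return np.round(f(r) * oscale).astype(np.int64)   # in_dim = 2*alpha + 1 entries


def lookup(xq, table, alpha):
    """cvm_lut modeled as indexing: offset = Xq + alpha maps [-alpha, alpha] to [0, 2*alpha]."""
    offset = np.clip(xq + alpha, 0, 2 * alpha)          # clipping is an added assumption
    return table[offset]


sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))
xs, alpha, oscale = 16, 128, 2 ** 15 - 1
table = build_table(sigmoid, alpha, xs, oscale)
x = np.array([-3.0, -0.5, 0.0, 1.25, 7.0])
xq = np.round(x * xs).astype(np.int64)
y = lookup(xq, table, alpha) / oscale
assert np.allclose(y, sigmoid(x), atol=1e-4)
```
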
mrt.tfm_ops.reverse_sequence(op)

Reversing the symbol sequence may lead to a different (wrong) result, because the graph transformation is not equivalent.

Example

A ->  B -> C
  |-> D -> E

after reverse sequence is

B -> A ->  C
       |-> D -> E

which is invalid.

Notice:

The fuse_transpose pass has the same hidden problem.

mrt.tfm_ops.reverse_transpose(op)

For a symbol with a single Transpose input, reverse the sequence if the two ops are swappable.

X -> Transpose -> op

after reverse sequence is

X -> op -> Transpose

Notice:

The axes of the Transpose remain the same before and after the swap.

mrt.cvm_op

Customized Op-level realization of MxNet Forward Computing Framework.

MxNet customized operator property class.

Only crucial parts of the customized forward operator implementations are elaborated.

class mrt.cvm_op.Clip(precision, **kwargs)
forward(is_train, req, in_data, out_data, aux)

MxNet customized operator forward implementation.

Clip the input within [-2^prec, 2^prec].

\[rnd = round(X)\]

where X is the input tensor.

\[out = clip(rnd, -2^{prec}, 2^{prec})\]

where prec is an operator attribute representing integer bits.

class mrt.cvm_op.LeftShift(precision, shift_bit, **kwargs)
forward(is_train, req, in_data, out_data, aux)

MxNet customized operator forward implementation.

Left shift data by sb bits and clip within [-2^prec, 2^prec].

\[rnd = round(X)\]

where X is the input tensor.

\[val = rnd * 2^{sb}\]
\[out = clip(val, -2^{prec}, 2^{prec})\]

where prec is an operator attribute representing integer bits, sb represents the left shift bits.

class mrt.cvm_op.RightShift(precision, shift_bit, **kwargs)
forward(is_train, req, in_data, out_data, aux)

MxNet customized operator forward implementation.

Right shift data by sb bits and clip within [-2^prec, 2^prec].

\[rnd = round(X)\]

where X is the input tensor.

\[val0 = floor(rnd / 2^{sb-1})\]
\[val1 = floor((val0+1) / 2)\]
\[out = clip(val1, -2^{prec}, 2^{prec})\]

where prec is an operator attribute representing integer bits, sb represents the right shift bits.
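
The three forward formulas can be written directly in NumPy, using the clipping bounds exactly as documented above. Two assumptions: round is taken to be np.round (which rounds halves to even), and the right shift requires sb >= 1. The floor((val0 + 1) / 2) step makes the right shift round half up, e.g. 6 shifted right by 2 bits gives 2 rather than 1.

```python
import numpy as np


def cvm_clip(x, prec):
    return np.clip(np.round(x), -2 ** prec, 2 ** prec)


def cvm_left_shift(x, prec, sb):
    val = np.round(x) * 2 ** sb
    return np.clip(val, -2 ** prec, 2 ** prec)


def cvm_right_shift(x, prec, sb):
    # Requires sb >= 1; the two floors implement rounding half up.
    rnd = np.round(x)
    val0 = np.floor(rnd / 2 ** (sb - 1))
    val1 = np.floor((val0 + 1) / 2)
    return np.clip(val1, -2 ** prec, 2 ** prec)


# 5, 6, 7, -6 shifted right by 2 bits -> 1, 2, 2, -1
print(cvm_right_shift(np.array([5.0, 6.0, 7.0, -6.0]), prec=8, sb=2))
```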

class mrt.cvm_op.LUT(in_dim, **kwargs)
forward(is_train, req, in_data, out_data, aux)

MxNet customized operator forward implementation.

Embed X with respect to T with a vocabulary size of indim, where X is the input tensor and T is the weight tensor. The dimension of the embedding vectors is 1.

\[val = Embedding(X, T, indim, 1)\]
\[out = squeeze(val, axis=-1)\]

where indim is an operator attribute representing vocabulary size of input indices.

class mrt.cvm_op.ClipProp(precision=8, shift_bit=0)

MxNet cvm_clip operator property class.

class mrt.cvm_op.LeftShiftProp(precision=8, shift_bit=0)

MxNet cvm_left_shift operator property class.

class mrt.cvm_op.RightShiftProp(precision=8, shift_bit=0)

MxNet cvm_right_shift operator property class.

class mrt.cvm_op.LUTProp(in_dim)

MxNet cvm_lut operator property class.