MRT Transformer API

mrt.transformer

MRT Interface API

Refactoring of the source code using the registry pattern. Coding rules enforced with pylint. A collected controller for hyper-parameters. A simplified public API.

mrt.transformer.Model

class mrt.transformer.Model(symbol, params, dtype='float64')

Wrapper around an MXNet symbol and params, designed to provide a user-friendly model API.

input_names()

List model input names.

static load(symbol_file, params_file)

Load a model from disk.

output_names()

List model output names.

save(symbol_file, params_file)

Dump the model to disk.

to_graph(dtype='float32', ctx=cpu(0))

Convenience helper that creates a model runtime; returns a gluon.nn.SymbolBlock.
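A minimal sketch of how the Model methods above might be combined. The helper `copy_model` and its argument names are hypothetical and not part of MRT; `model_cls` is assumed to be `mrt.transformer.Model` (passed in so the helper itself carries no import), and only `load`, `input_names`, `output_names` and `save` as documented here are called.

```python
def copy_model(model_cls, symbol_file, params_file,
               out_symbol_file, out_params_file):
    """Load a model, record its input/output names, and save a copy.

    Hypothetical helper: `model_cls` is assumed to expose the Model API
    documented above (static load, input_names, output_names, save).
    """
    # Static constructor documented as load(symbol_file, params_file).
    model = model_cls.load(symbol_file, params_file)

    # Collect the names so the caller can inspect the model interface.
    names = {
        "inputs": list(model.input_names()),
        "outputs": list(model.output_names()),
    }

    # Dump the same model to a new pair of files.
    model.save(out_symbol_file, out_params_file)
    return names

# Intended use, assuming MRT is installed (file names are placeholders):
#   from mrt.transformer import Model
#   copy_model(Model, "net.json", "net.params", "copy.json", "copy.params")
```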

mrt.transformer.MRT

class mrt.transformer.MRT(model, input_prec=8)

An MRT quantization class containing many helper functions.

Quantization Procedures

  1. prepare: initialize the model graph, e.g. fuse_constant, rewrite, validate, etc.;

  2. calibration: calculate the internal thresholds of layers;

  3. quantization: quantize the floating-point parameters into INT(p) precision with scales, using floating-point data to simulate the real environment of integer dataflow.
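To make the calibration and quantization steps concrete, here is a small standalone illustration in plain Python. It is not MRT's implementation: it assumes a symmetric scheme in which the threshold is the largest absolute value observed and the scale maps that threshold onto the largest signed INT(p) value. All function names are hypothetical.

```python
def calibrate_threshold(values):
    """Assumed calibration rule: threshold = largest absolute value seen."""
    return max(abs(v) for v in values)

def quantize_values(values, threshold, prec=8):
    """Map floats to signed INT(prec) integers using a symmetric scale."""
    if threshold <= 0:
        raise ValueError("threshold must be positive")
    qmax = 2 ** (prec - 1) - 1           # e.g. 127 for prec=8
    scale = qmax / threshold             # float value * scale -> integer grid
    # Round to the nearest integer and clip into [-qmax, qmax].
    ints = [max(-qmax, min(qmax, round(v * scale))) for v in values]
    return ints, scale

def dequantize_values(ints, scale):
    """Recover approximate floats; this simulates the integer dataflow."""
    return [q / scale for q in ints]

weights = [0.5, -1.0, 0.25, 0.0]
th = calibrate_threshold(weights)
q_weights, scale = quantize_values(weights, th, prec=8)
approx = dequantize_values(q_weights, scale)
```

With this scheme the round-trip error of each value is at most half a quantization step, i.e. `0.5 / scale`, as long as the value lies within the threshold.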

calibrate(ctx=cpu(0), lambd=None, old_ths=None)

Calibrate the current model after the MRT data has been set.

Parameters
  • ctx (mxnet.context) – Context on which intermediate results are stored.

  • lambd (double) – Hyperparameter

  • old_ths (dict) – Optional reference threshold dict.

Returns

th_dict – Threshold dict of node-level output.

Return type

dict

get_inputs_ext()

Get the input_ext of the input after quantization.

get_maps()

Get the map from current output names to old output names after calibration or quantization.

get_output_scales()

Get the output scale of the model after quantization.

static load(model_name, datadir='./data')

Load and create an MRT instance.

The given path should contain the corresponding ‘.json’ and ‘.params’ files storing model information and the ‘.ext’ file storing MRT information.

Returns

mrt – The mrt instance.

Return type

MRT

quantize()

Quantize the current model after calibration.

Returns

qmodel – The quantized model.

Return type

Model
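The documented call order (set data, calibrate, quantize) can be sketched as a small driver. `run_quantization` is a hypothetical helper, not part of MRT; `mrt` is assumed to be an already constructed `mrt.transformer.MRT` instance, and only methods listed in this reference are called.

```python
def run_quantization(mrt, data, ctx=None, lambd=None, old_ths=None):
    """Calibrate and quantize an MRT instance on one batch of data.

    Hypothetical helper. Returns (th_dict, qmodel) as described by the
    calibrate() and quantize() entries above.
    """
    # Data must be set before calibration.
    mrt.set_data(data)

    # Only forward ctx when given, so calibrate() keeps its own default.
    kwargs = {"lambd": lambd, "old_ths": old_ths}
    if ctx is not None:
        kwargs["ctx"] = ctx
    th_dict = mrt.calibrate(**kwargs)

    # Quantization follows calibration and yields the quantized Model.
    qmodel = mrt.quantize()
    return th_dict, qmodel

# Intended use, assuming MRT is installed (file names are placeholders):
#   from mrt.transformer import Model, MRT
#   mrt = MRT(Model.load("net.json", "net.params"), input_prec=8)
#   th_dict, qmodel = run_quantization(mrt, batch)
```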

save(model_name, datadir='./data')

Save the current MRT instance to disk.

set_data(data)

Set the data before calibration.

set_input_prec(prec)

Set the input precision before quantization.

set_output_prec(prec)

Set the output precision before quantization.

set_restore(name)

Manually set the threshold of the node output, given node name.

set_shift_bits(val)

Set the hyperparameter shift_bits before quantization.

set_softmax_lambd(val)

Set the hyperparameter softmax_lambd before quantization.

set_th_dict(th_dict)

Manually set the threshold dict.

set_threshold(name, threshold)

Manually set the threshold of a single node output, given the node name.
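The setters above can be grouped into one configuration step. The sketch below is a hypothetical convenience wrapper, not part of MRT; it calls only the documented setters and skips any option left as `None`. The `thresholds` argument is assumed to map node names to threshold values.

```python
def configure(mrt, input_prec=None, output_prec=None, shift_bits=None,
              softmax_lambd=None, thresholds=None):
    """Apply optional settings to an MRT instance before quantization.

    Hypothetical helper: each argument maps onto one documented setter.
    """
    if input_prec is not None:
        mrt.set_input_prec(input_prec)
    if output_prec is not None:
        mrt.set_output_prec(output_prec)
    if shift_bits is not None:
        mrt.set_shift_bits(shift_bits)
    if softmax_lambd is not None:
        mrt.set_softmax_lambd(softmax_lambd)

    # Per-node overrides go through set_threshold(name, threshold).
    for name, threshold in (thresholds or {}).items():
        mrt.set_threshold(name, threshold)
    return mrt
```

Keeping the options as keyword arguments makes it easy to store them in a config file and pass them through with `configure(mrt, **options)`.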

mrt.transformer.ModelMerger

class mrt.transformer.ModelMerger(base_model, top_model, base_name_maps=None)

A wrapper class for the model merge tool.

get_output_scales(base_oscales, maps)

Get the model output scales after merge.

Parameters
  • base_oscales (list) – Base model output scales.

  • maps (dict) – Base name maps; must be specified.

Returns

ret – The output scales of the merged model.

Return type

list

merge(callback=None)

Get the merged model.

Parameters

callback (func) – Optional callback function for updating the top node attributes.

Returns

ret – The merged model.

Return type

Model

mrt.transformer.ModelSpliter

class mrt.transformer.ModelSpliter(model, keys)

A wrapper class for the model split tool.

split()

Get the split models with respect to the specified keys.

Returns

ret – The split models.

Return type

Model tuple
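A rough sketch of how ModelSpliter and ModelMerger might be used together. `split_then_merge` is a hypothetical helper; the two classes are passed in as arguments so the helper carries no import, and the sketch assumes that `split()` returns exactly two models in the order (base, top), which this reference does not state explicitly.

```python
def split_then_merge(model, keys, spliter_cls, merger_cls,
                     base_name_maps=None, callback=None):
    """Split a model at `keys`, then merge the two parts back together.

    Hypothetical helper. Assumption: split() yields a (base, top) pair.
    """
    # ModelSpliter(model, keys).split() is documented to return a Model tuple.
    base, top = spliter_cls(model, keys).split()

    # ModelMerger(base_model, top_model, base_name_maps=None) per the signature above.
    merger = merger_cls(base, top, base_name_maps)

    # merge() accepts an optional callback for updating top node attributes.
    merged = merger.merge(callback=callback)
    return base, top, merged

# Intended use, assuming MRT is installed (the key name is a placeholder):
#   from mrt.transformer import ModelSpliter, ModelMerger
#   base, top, merged = split_then_merge(model, ["some_node"],
#                                        ModelSpliter, ModelMerger)
```

In practice the base part would typically be processed (for example quantized) between the split and the merge; the sketch omits that step to stay within the documented API.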