MRT main2 API

Main2 Stages

The main2 pipeline consists of the following seven stages, described in order below.

Prepare

The raw model needs to be prepared to make it compatible with the MRT pipeline:

model = Model.load(sym_path, prm_path)
model.prepare(set_batch(input_shape, 1))

Split Model

A model can be split into a base model and a top model by specifying keys:

base, top = model.split(keys)
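Conceptually, splitting at keys cuts the graph so that the key nodes and everything they depend on form the base, while the remaining nodes form the top, which consumes the key outputs as its inputs. The toy sketch below illustrates that idea on a plain dictionary graph; the graph representation and the `split_graph` helper are hypothetical and are not MRT's implementation.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy graph: node name -> list of input node names.
Graph = Dict[str, List[str]]


def ancestors(graph: Graph, keys: List[str]) -> Set[str]:
    """Return the key nodes together with every node they transitively depend on."""
    seen: Set[str] = set()
    stack = list(keys)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(graph[node])
    return seen


def split_graph(graph: Graph, keys: List[str]) -> Tuple[Graph, Graph]:
    """Split a toy graph into (base, top) at the key nodes.

    Assumes the keys form a complete cut, i.e. top nodes only consume
    key outputs, never other base nodes.
    """
    base_nodes = ancestors(graph, keys)
    base = {n: graph[n] for n in graph if n in base_nodes}
    # In the top graph each key node becomes a plain input with no predecessors.
    top: Graph = {k: [] for k in keys}
    for n in graph:
        if n not in base_nodes:
            top[n] = list(graph[n])
    return base, top


g = {"data": [], "conv": ["data"], "relu": ["conv"], "head": ["relu"]}
base, top = split_graph(g, ["relu"])
print(sorted(base))  # ['conv', 'data', 'relu']
print(sorted(top))   # ['head', 'relu']
```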

Calibration

An MRT instance is created for the calibration and quantization stages. It is obtained from the whole model when keys is empty, and from the base model otherwise:

mrt = model.get_mrt() if keys == '' else base.get_mrt()

Calibration is run for calibrate_num iterations. Each iteration feeds one batch of data to the MRT instance and calibrates it with the given lambd on context ctx:

for i in range(calibrate_num):
    data, _ = data_iter_func()
    mrt.set_data(data)
    mrt.calibrate(lambd=lambd, ctx=ctx)
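Each iteration of the loop above updates statistics that later determine quantization thresholds. The NumPy sketch below mimics such an accumulation for a single tensor by keeping the largest per-batch absolute value. This document does not say how MRT uses `lambd`, so the clipping rule `mean + lambd * std` is only an assumption for illustration, and `batch_threshold`, `calibrate_threshold` and `fake_data_iter` are hypothetical names.

```python
import numpy as np


def batch_threshold(x, lambd=None):
    """Per-batch threshold: max |x|, optionally tightened by an assumed lambd rule."""
    absx = np.abs(x)
    amax = float(absx.max())
    if lambd is None:
        return amax
    # Assumption for illustration only: clip outliers at mean + lambd * std of |x|.
    return min(amax, float(absx.mean() + lambd * absx.std()))


def calibrate_threshold(data_iter_func, calibrate_num, lambd=None):
    """Run calibrate_num calibration steps, keeping the largest per-batch threshold."""
    threshold = 0.0
    for i in range(calibrate_num):
        data, _ = data_iter_func()  # same (data, label) convention as the loop above
        threshold = max(threshold, batch_threshold(data, lambd))
    return threshold


# Hypothetical stand-in for dataset.iter_func(): random batches of shape (16, 8).
rng = np.random.default_rng(0)


def fake_data_iter():
    return rng.normal(size=(16, 8)), None


print(calibrate_threshold(fake_data_iter, calibrate_num=4))
print(calibrate_threshold(fake_data_iter, calibrate_num=4, lambd=3.0))
```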

Quantization

The MRT instance performs the quantization. If needed, the user can first set predefined parameters such as the input precision, output precision, softmax lambd, shift bits, and the threshold of a particular node:

mrt.set_input_prec(input_precision)
mrt.set_output_prec(output_precision)
mrt.set_softmax_lambd(softmax_lambd)
mrt.set_shift_bits(shift_bits)
mrt.set_threshold(name, threshold)

Then, the quantization process is performed as follows:

mrt.quantize()
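As background for parameters such as the precision and the per-node threshold, the sketch below shows generic symmetric linear quantization: a threshold `t` and a signed precision `p` give the scale `(2**(p-1) - 1) / t`, and values are scaled, rounded, and clipped into the integer range. This textbook scheme is assumed for illustration only; the exact formulas MRT uses are not given in this document, and `quantize_symmetric` is a hypothetical helper.

```python
import numpy as np


def quantize_symmetric(x, threshold, prec):
    """Quantize x into signed prec-bit integers.

    Returns (q, scale) such that q / scale approximates x inside [-threshold, threshold].
    """
    qmax = 2 ** (prec - 1) - 1  # e.g. 127 for 8 bits
    scale = qmax / threshold
    q = np.clip(np.round(x * scale), -qmax, qmax).astype(np.int32)
    return q, scale


x = np.array([-1.0, -0.5, 0.0, 0.25, 2.0])
q, scale = quantize_symmetric(x, threshold=1.0, prec=8)
print(q.tolist())  # [-127, -64, 0, 32, 127]
print(q / scale)   # dequantized approximation; 2.0 saturates at the threshold 1.0
```

Values beyond the threshold saturate, which is why the choice of threshold during calibration matters.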

Merge Model

By specifying the quantized base model (qmodel), the top model, and the node key maps of the MRT instance, the user can create a model merger instance:

model_merger = Model.merger(qmodel, top, mrt.get_maps())

By specifying a callback merging function, the user can merge the top and base models, and then get the output scales of the merged model by configuring oscale_maps:

qmodel = model_merger.merge(callback=mergefunc)
oscale_maps = _get_val(
    cfg, sec, 'Oscale_maps', dtype=PAIR(str_t, str_t))
oscales = model_merger.get_output_scales(
    mrt_oscales, oscale_maps)

Evaluation

The evaluation stage reduces the quantized model and compares its performance against the original model:

org_model = Model.load(sym_path, prm_path)
graph = org_model.to_graph(ctx=ctx)
dataset = ds.DS_REG[ds_name](set_batch(input_shape, batch))
data_iter_func = dataset.iter_func()
metric = dataset.metrics()

...

split_batch = batch//ngpus
rqmodel = reduce_graph(qmodel, {
    'data': set_batch(input_shape, split_batch)})
qgraph = rqmodel.to_graph(ctx=ctx)
qmetric = dataset.metrics()

...

utils.multi_validate(evalfunc, data_iter_func, quantize,
                     iter_num=iter_num,
                     logger=logging.getLogger('mrt.validate'),
                     batch_size=batch)
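The evaluation code divides the total batch across GPUs (`split_batch = batch//ngpus`) and reduces the quantized graph for that per-device batch size. A minimal sketch of such an even split with NumPy, assuming the batch must be divisible by the number of devices (a constraint this document does not state); `split_for_devices` is a hypothetical helper.

```python
import numpy as np


def split_for_devices(batch_data, ngpus, axis=0):
    """Split a batch along the batch axis into ngpus chunks of size batch // ngpus."""
    batch = batch_data.shape[axis]
    if batch % ngpus != 0:
        # Assumed constraint: otherwise batch // ngpus would silently drop samples.
        raise ValueError(f"batch {batch} is not divisible by ngpus {ngpus}")
    return np.split(batch_data, ngpus, axis=axis)


chunks = split_for_devices(np.zeros((16, 3, 224, 224)), ngpus=4)
print([c.shape for c in chunks])  # four chunks with a batch dimension of 4
```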

Compilation

The compilation stage converts the model from MXNet to CVM and dumps it:

qmodel.to_cvm(model_name_tfm, datadir=dump_dir,
    input_shape=set_batch(input_shape, batch),
    target=device_type, device_ids=device_ids)

It also dumps sample data and the ext file:

import numpy as np
from os import path

# dump_data is assumed to hold one batch of float input data from the dataset.
inputs_ext = mrt.get_inputs_ext()
dump_data = sim.load_real_data(
    dump_data.astype("float64"), 'data', inputs_ext)
model_root = path.join(dump_dir, model_name_tfm)
np.save(path.join(model_root, "data.npy"),
        dump_data.astype('int8').asnumpy())
infos = {
    "inputs_ext": inputs_ext,
    "oscales": oscales,
    "input_shapes": input_shape,
}
sim.save_ext(path.join(model_root, "ext"), infos)

Main2 Helper Functions

The following helper functions are used to parse model configurations.

mrt.main2.set_batch(input_shape, batch)

Get the input shape obtained by setting the batch axis of the original input shape to the specified batch value.

Parameters
  • input_shape (tuple) – The input shape with the batch axis unset.

  • batch (int) – The batch value.

Returns

ishape – The input shape with the batch axis set to batch.

Return type

tuple
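A minimal sketch of the behaviour, assuming the unset batch axis is marked with `-1` (a convention assumed here, not stated in this document). This is a hypothetical reimplementation for illustration, not the main2 source.

```python
def set_batch(input_shape, batch):
    """Return input_shape with the unset batch axis (assumed to be -1) replaced by batch."""
    return tuple(batch if dim == -1 else dim for dim in input_shape)


print(set_batch((-1, 3, 224, 224), 16))  # (16, 3, 224, 224)
print(set_batch((-1, 3, 224, 224), 1))   # (1, 3, 224, 224)
```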

mrt.main2.batch_axis(input_shape)

Get the batch axis entry of an input shape.

Parameters

input_shape (tuple) – The input data shape of the dataset.

Returns

axis – The index of the batch axis in the input shape.

Return type

int
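A sketch under the same assumption as above, namely that the batch axis is the single entry equal to `-1`; raising `ValueError` for zero or several such entries is also an assumption.

```python
def batch_axis(input_shape):
    """Return the index of the batch axis, assumed to be the only entry equal to -1."""
    axes = [i for i, dim in enumerate(input_shape) if dim == -1]
    if len(axes) != 1:
        raise ValueError(f"expected exactly one -1 entry in {input_shape}")
    return axes[0]


print(batch_axis((-1, 3, 224, 224)))  # 0
print(batch_axis((10, -1, 128)))      # 1
```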

mrt.main2._check(expression, section, option, message='Not a valid value')

Check whether a condition on a main2 configuration option holds, and report an error message if it does not.

Parameters
  • expression (bool) – The condition to be checked.

  • section (string) – The section of the configuration file.

  • option (string) – The option of the section.

  • message (string) – The error message to be reported.
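A sketch of such a check. This document only says that an error message is reported, so raising `ValueError` and the message layout are assumptions; `check` is a hypothetical stand-in for `_check`, and the section and option names in the usage lines are illustrative.

```python
def check(expression, section, option, message="Not a valid value"):
    """Raise ValueError naming the section and option when expression is False."""
    if not expression:
        raise ValueError(f"[{section}] {option}: {message}")


calibrate_num = 0
try:
    check(calibrate_num > 0, "CALIBRATION", "Calibrate_num",
          message="must be a positive integer")
except ValueError as err:
    print(err)  # [CALIBRATION] Calibrate_num: must be a positive integer
```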

mrt.main2._get_path(config, section, option, is_dir=False, dpath=<object object>)

Get and validate the path specified in the configuration file.

Parameters
  • config (configparser.ConfigParser) – The initialized config parser.

  • section (string) – The section of the configuration file.

  • option (string) – The option of the section.

  • is_dir (bool) – Whether the path is a directory.

  • dpath (string) – The default path.

Returns

path – The verified absolute path specified in the option.

Return type

string
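A sketch of the lookup and validation using the standard `configparser` and `os.path` modules. The `<object object>` default in the signature suggests a sentinel meaning "no default", which this sketch mirrors with `_NO_DEFAULT`. The remaining behaviour (falling back to `dpath`, resolving relative paths against the current directory, requiring the path to exist) is assumed, and `get_path` is a hypothetical name.

```python
import configparser
import os

_NO_DEFAULT = object()  # sentinel, like the <object object> default in the signature


def get_path(config, section, option, is_dir=False, dpath=_NO_DEFAULT):
    """Return the verified absolute path stored in config[section][option]."""
    if config.has_option(section, option):
        raw = config.get(section, option)
    elif dpath is not _NO_DEFAULT:
        raw = dpath
    else:
        raise KeyError(f"[{section}] {option} is required")
    abspath = os.path.abspath(os.path.expanduser(raw))
    exists = os.path.isdir(abspath) if is_dir else os.path.isfile(abspath)
    if not exists:
        kind = "directory" if is_dir else "file"
        raise FileNotFoundError(f"[{section}] {option}: {kind} {abspath} not found")
    return abspath


cfg = configparser.ConfigParser()
cfg.read_string("[PREPARE]\nModel_dir = .\n")
print(get_path(cfg, "PREPARE", "Model_dir", is_dir=True))  # absolute path of the cwd
```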

mrt.main2._get_ctx(config, section, dctx=cpu(0))

Get the context specified in the configuration file.

Parameters
  • config (configparser.ConfigParser) – The initialized config parser.

  • section (string) – The section of the configuration file.

  • dctx (mxnet.context) – The default context.

Returns

ctx – The context specified in the section.

Return type

mxnet.context

mrt.main2.ARRAY(dtype)

Wrap a data type into an array type whose elements all share that type.

Parameters

dtype (string) – The data type of the array elements.

Returns

ret – The wrapped data type name.

Return type

string

mrt.main2.PAIR(*dtypes)

Wrap data types into a (possibly multi-level) map type.

Parameters

dtypes (list of string) – The data types to be wrapped into a multi-level map.

Returns

ret – The wrapped data type name.

Return type

string
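Both wrappers return a string type name that the value parser can later dispatch on. The sketch below assumes simple string concatenation of type tags. Only `'_str_'` appears in this document (as the `dtype` default of `_get_val`); the other tag spellings, the encoding, and the per-level ordering are hypothetical and may differ from main2.

```python
# Hypothetical type tags; only "_str_" is taken from this document.
str_t = "_str_"
int_t = "_int_"


def ARRAY(dtype):
    """Type name for an array whose elements all have type dtype (assumed encoding)."""
    return "_array_" + dtype


def PAIR(*dtypes):
    """Type name for a (possibly multi-level) map built from dtypes, one per level."""
    return "_pair_" + "".join(dtypes)


print(ARRAY(int_t))        # _array__int_
print(PAIR(str_t, str_t))  # _pair__str__str_
```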

mrt.main2._get_val(config, section, option, dtype='_str_', dval=<object object>)

Get the value of the option in the section, with data type and default value specified.

Parameters
  • config (configparser.ConfigParser) – The initialized config parser.

  • section (string) – The section of the configuration file.

  • option (string) – The option of the section.

  • dtype (string) – The data type to be recognised by the parser.

  • dval (string, int, etc.) – The default value.

Returns

val – The parsed value.

Return type

string, int, etc.

mrt.main2._cast_val(section, option, val_, dtype='_str_')

Cast the raw value of the option in the section to the specified data type.

Parameters
  • section (string) – The section of the configuration file.

  • option (string) – The option of the section.

  • val_ (string) – The raw value of the option to be cast.

  • dtype (string) – The data type to be recognised by the parser.

Returns

val – The parsed value.

Return type

string, int, etc.
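Judging by the signatures, `_get_val` looks an option up (falling back to `dval`) and `_cast_val` converts a raw string into `dtype`. The sketch below assumes that division of work. It supports only a few hypothetical tags (`"_str_"` is the documented default; `"_int_"`, `"_float_"` and `"_tuple_"` are assumed), parses tuples with `ast.literal_eval`, and uses the hypothetical names `get_val` and `cast_val`; the syntax accepted by the real parser may differ.

```python
import ast
import configparser

_NO_DEFAULT = object()  # sentinel default, mirroring <object object> in the signatures


def cast_val(section, option, val_, dtype="_str_"):
    """Convert the raw option string val_ to dtype (only a few assumed tags)."""
    try:
        if dtype == "_str_":
            return val_
        if dtype == "_int_":
            return int(val_)
        if dtype == "_float_":
            return float(val_)
        if dtype == "_tuple_":
            return tuple(ast.literal_eval(val_))
    except (ValueError, SyntaxError, TypeError) as err:
        raise ValueError(f"[{section}] {option}: cannot cast {val_!r} to {dtype}") from err
    raise ValueError(f"[{section}] {option}: unsupported dtype {dtype}")


def get_val(config, section, option, dtype="_str_", dval=_NO_DEFAULT):
    """Return the option cast to dtype, or dval when the option is missing."""
    if not config.has_option(section, option):
        if dval is _NO_DEFAULT:
            raise KeyError(f"[{section}] {option} is required")
        return dval
    return cast_val(section, option, config.get(section, option), dtype)


cfg = configparser.ConfigParser()
cfg.read_string("[CALIBRATION]\nCalibrate_num = 8\n")
print(get_val(cfg, "CALIBRATION", "Calibrate_num", dtype="_int_"))  # 8
print(get_val(cfg, "CALIBRATION", "Device_type", dval="cpu"))      # cpu
```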

mrt.main2._load_fname(prefix, suffix=None, with_ext=False)

Get the model file names at a given stage.

Parameters
  • prefix (string) – The file path without an extension.

  • suffix (string) – The file suffix with respect to a given stage of MRT.

  • with_ext (bool) – Whether to include the ext file.

Returns

files – The loaded file names.

Return type

tuple of string
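A sketch of a plausible naming scheme. It assumes MXNet-style `.json` (symbol) and `.params` (parameters) files plus an optional `.ext` file, with the stage suffix joined to the prefix by a dot. These conventions, the helper name `load_fname`, and the suffix `"mrt.quantize"` in the usage line are hypothetical.

```python
def load_fname(prefix, suffix=None, with_ext=False):
    """Return the model file names for a stage, e.g. ('p.s.json', 'p.s.params')."""
    base = prefix if suffix is None else f"{prefix}.{suffix}"
    files = [f"{base}.json", f"{base}.params"]
    if with_ext:
        files.append(f"{base}.ext")
    return tuple(files)


print(load_fname("/data/resnet", "mrt.quantize", with_ext=True))
# ('/data/resnet.mrt.quantize.json', '/data/resnet.mrt.quantize.params', '/data/resnet.mrt.quantize.ext')
```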

mrt.main2._checkpoint_exist(sec, *flist)

Check whether the checkpoint files of the given MRT stage exist.

Parameters
  • sec (string) – The MRT stage to be checked.

  • flist (list of string) – The checkpoint files to be checked.
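A plausible sketch, assuming a stage counts as checkpointed only when every listed file exists, so that the stage can be skipped on a rerun. Logging the missing files under the stage name `sec` is also an assumption, and `checkpoint_exist` and the logger name are hypothetical.

```python
import logging
import os

logger = logging.getLogger("mrt.checkpoint_sketch")  # hypothetical logger name


def checkpoint_exist(sec, *flist):
    """Return True if every file in flist exists, i.e. stage sec can be skipped."""
    missing = [f for f in flist if not os.path.exists(f)]
    if missing:
        logger.info("stage %s: missing checkpoint files %s", sec, missing)
        return False
    return True


print(checkpoint_exist("PREPARE", os.curdir))            # True
print(checkpoint_exist("PREPARE", "no_such_file.json"))  # False unless that file exists
```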