MRT main2 API
Main2 Stages
main2 runs the following seven stages, which are described below in order.
Prepare
The raw model needs to be prepared so that it is compatible with the MRT pipeline:
model = Model.load(sym_path, prm_path)
model.prepare(set_batch(input_shape, 1))
Split Model
A model can be split into a base model and a top model by specifying keys:
base, top = model.split(keys)
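To make the split concrete, the following is a minimal sketch on a toy model represented as an ordered list of layer names. It is not MRT's implementation: `split_layers` is a hypothetical helper, and it assumes that the base part ends at the last layer named in `keys`.

```python
def split_layers(layers, keys):
    """Split an ordered list of layer names into (base, top).

    Assumption for illustration only: the base part ends at the last
    layer named in `keys`; everything after it goes to the top part.
    """
    if not keys:
        # No split requested: the whole model is the base.
        return list(layers), []
    missing = [k for k in keys if k not in layers]
    if missing:
        raise KeyError("unknown split keys: %s" % missing)
    cut = max(layers.index(k) for k in keys) + 1
    return layers[:cut], layers[cut:]


layers = ["data", "conv0", "relu0", "fc0", "softmax"]
base, top = split_layers(layers, ["fc0"])
print(base)  # ['data', 'conv0', 'relu0', 'fc0']
print(top)   # ['softmax']
```

With an empty `keys`, the sketch returns the whole model as the base, which mirrors the `keys == ''` case handled in the calibration stage.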
Calibration
An MRT instance is created for the calibration and quantization stages, from the whole model if no split keys are given and from the base model otherwise:
mrt = model.get_mrt() if keys == '' else base.get_mrt()
Calibration is executed by specifying the number of calibration iterations, lambd, and the calibration data:
for i in range(calibrate_num):
    data, _ = data_iter_func()
    mrt.set_data(data)
    mrt.calibrate(lambd=lambd, ctx=ctx)
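The loop above feeds one batch per iteration and lets the instance update its statistics. As a rough illustration of what such a loop could accumulate, the sketch below tracks a single threshold across batches. `ToyCalibrator` is a hypothetical stand-in, not the MRT class, and the clipping rule tied to `lambd` (capping the range at |mean| + lambd * std) is an assumption made for this example.

```python
import math


class ToyCalibrator:
    """Hypothetical stand-in for an MRT instance; tracks one threshold."""

    def __init__(self):
        self.data = None
        self.threshold = 0.0

    def set_data(self, data):
        self.data = list(data)

    def calibrate(self, lambd=None):
        # Largest absolute value seen in the current batch.
        absmax = max(abs(x) for x in self.data)
        if lambd is not None:
            # Assumed clipping rule: cap the range at |mean| + lambd * std.
            mean = sum(self.data) / len(self.data)
            var = sum((x - mean) ** 2 for x in self.data) / len(self.data)
            absmax = min(absmax, abs(mean) + lambd * math.sqrt(var))
        # Keep the widest range observed across all calibration batches.
        self.threshold = max(self.threshold, absmax)
        return self.threshold


batches = [[0.5, -1.0, 0.25], [2.0, -0.5, 0.0], [1.5, 1.0, -3.0]]
cal = ToyCalibrator()
for batch in batches:
    cal.set_data(batch)
    cal.calibrate(lambd=None)
print(cal.threshold)  # 3.0
```

Passing a finite `lambd` in this sketch makes the threshold less sensitive to outliers, at the cost of clipping them later.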
Quantization
An MRT instance performs the quantization process. Before quantizing, the user can set predefined parameters on the instance if needed, such as the input precision, the output precision, the softmax lambd, the shift bits, and the threshold for a particular node:
mrt.set_input_prec(input_precision)
mrt.set_output_prec(output_precision)
mrt.set_softmax_lambd(softmax_lambd)
mrt.set_shift_bits(shift_bits)
mrt.set_threshold(name, threshold)
Then, the quantization process is performed as follows:
mrt.quantize()
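To show how a threshold and a precision interact, here is a minimal sketch of generic symmetric quantization. It is a textbook scheme used for illustration and is not claimed to match MRT's internal arithmetic; `quantize_symmetric` is a hypothetical helper.

```python
def quantize_symmetric(values, threshold, prec=8):
    """Map floats in [-threshold, threshold] to signed `prec`-bit integers."""
    qmax = 2 ** (prec - 1) - 1   # 127 for 8 bits
    scale = qmax / threshold     # float-to-integer scale factor
    quantized = []
    for v in values:
        q = int(round(v * scale))
        quantized.append(max(-qmax, min(qmax, q)))  # clip to the integer range
    return quantized, scale


q, scale = quantize_symmetric([0.0, 0.25, -1.0, 2.0], threshold=1.0, prec=8)
print(q)      # [0, 32, -127, 127]
print(scale)  # 127.0
```

The last input, 2.0, lies outside the threshold and is clipped to 127, which is why a well-chosen threshold from calibration matters.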
Merge Model
By specifying the base and top models along with the corresponding node key maps, the user can create a model merger instance:
model_merger = Model.merger(qmodel, top, mrt.get_maps())
By specifying a merging callback function, the user can merge the base and top models, and then get the output scales by configuring oscale_maps:
qmodel = model_merger.merge(callback=mergefunc)
oscale_maps = _get_val(
    cfg, sec, 'Oscale_maps', dtype=PAIR(str_t, str_t))
oscales = model_merger.get_output_scales(
    mrt_oscales, oscale_maps)
Evaluation
The evaluation stage reduces the quantized model and compares its performance with that of the original model:
org_model = Model.load(sym_path, prm_path)
graph = org_model.to_graph(ctx=ctx)
dataset = ds.DS_REG[ds_name](set_batch(input_shape, batch))
data_iter_func = dataset.iter_func()
metric = dataset.metrics()
...
split_batch = batch // ngpus
rqmodel = reduce_graph(qmodel, {
    'data': set_batch(input_shape, split_batch)})
qgraph = rqmodel.to_graph(ctx=ctx)
qmetric = dataset.metrics()
...
utils.multi_validate(evalfunc, data_iter_func, quantize,
                     iter_num=iter_num,
                     logger=logging.getLogger('mrt.validate'),
                     batch_size=batch)
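The comparison performed in this stage can be pictured as running several evaluation functions on the same batches and reporting one accuracy per function. The sketch below is a toy version of that idea; `validate`, the toy data, and both toy models are hypothetical and do not reproduce `utils.multi_validate`.

```python
def validate(eval_fns, data_iter_func, iter_num):
    """Run each evaluation function on the same batches; return accuracies."""
    correct = [0] * len(eval_fns)
    total = 0
    for _ in range(iter_num):
        data, labels = data_iter_func()
        total += len(labels)
        for i, fn in enumerate(eval_fns):
            preds = fn(data)
            correct[i] += sum(int(p == y) for p, y in zip(preds, labels))
    return [c / total for c in correct]


def make_iter():
    # Toy task: the label is 1 when the input is positive.
    batches = iter([
        ([0.9, -0.4, 0.2], [1, 0, 1]),
        ([-0.7, 0.05, 0.6], [0, 1, 1]),
    ])
    return lambda: next(batches)


def float_model(data):
    return [int(x > 0) for x in data]


def quantized_model(data):
    # Coarse rounding to multiples of 0.25 loses the small input 0.05.
    return [int(round(x * 4) > 0) for x in data]


acc_float, acc_quant = validate([float_model, quantized_model],
                                make_iter(), iter_num=2)
print(acc_float, acc_quant)
```

Evaluating both models on identical batches keeps the comparison fair, since any accuracy gap then comes from quantization alone.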
Compilation
The compilation stage includes the model conversion from MXNet to CVM and the model dump:
qmodel.to_cvm(model_name_tfm, datadir=dump_dir,
              input_shape=set_batch(input_shape, batch),
              target=device_type, device_ids=device_ids)
as well as the dump of the sample data and ext files:
dump_data = sim.load_real_data(
    dump_data.astype("float64"), 'data', mrt.get_inputs_ext())
model_root = path.join(dump_dir, model_name_tfm)
np.save(path.join(model_root, "data.npy"),
        dump_data.astype('int8').asnumpy())
infos = {
    "inputs_ext": inputs_ext,
    "oscales": oscales,
    "input_shapes": input_shape,
}
sim.save_ext(path.join(model_root, "ext"), infos)
Main2 Helper Functions
These helper functions form the user API for parsing model configurations.
- mrt.main2.set_batch(input_shape, batch)
Get the input shape with respect to a specified batch value and an original input shape.
- Parameters
input_shape (tuple) – The input shape with batch axis unset.
batch (int) – The batch value.
- Returns
ishape – The input shape with the value of batch axis equal to batch.
- Return type
tuple
- mrt.main2.batch_axis(input_shape)
Get the batch axis entry of an input shape.
- Parameters
input_shape (tuple) – The data shape related to dataset.
- Returns
axis – The batch axis entry of an input shape.
- Return type
int
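The two shape helpers above can be sketched as follows. This is an illustrative reimplementation, not the library source, and it assumes that the unset batch axis is marked with -1 in `input_shape`.

```python
def batch_axis(input_shape):
    """Return the index of the unset batch axis (assumed to be marked -1)."""
    axes = [i for i, dim in enumerate(input_shape) if dim == -1]
    assert len(axes) == 1, "expected exactly one batch axis, got %s" % (axes,)
    return axes[0]


def set_batch(input_shape, batch):
    """Return a copy of input_shape with the batch axis set to `batch`."""
    axis = batch_axis(input_shape)
    return tuple(batch if i == axis else dim
                 for i, dim in enumerate(input_shape))


print(batch_axis((-1, 3, 224, 224)))     # 0
print(set_batch((-1, 3, 224, 224), 16))  # (16, 3, 224, 224)
```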
- mrt.main2._check(expression, section, option, message='Not a valid value')
Check whether an operation of main2 is valid, and report an error message if it is not.
- Parameters
expression (bool) – The condition to be checked in main2.
section (string) – The section of the configuration file.
option (string) – The option of the section.
message (string) – The error message to be reported.
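A validation helper of this kind might look like the sketch below. The function name `check`, the exception type, the message format, and the section and option names are all assumptions chosen for illustration.

```python
def check(expression, section, option, message="Not a valid value"):
    """Raise an error naming the offending option when `expression` is false."""
    if not expression:
        raise ValueError("[%s] %s: %s" % (section, option, message))


# Hypothetical section and option names.
check(8 > 0, "QUANTIZATION", "Input_precision")  # passes silently
try:
    check(-1 > 0, "QUANTIZATION", "Input_precision",
          message="must be positive")
except ValueError as err:
    print(err)  # [QUANTIZATION] Input_precision: must be positive
```

Including the section and option in the message lets the user locate the faulty line in the configuration file directly.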
- mrt.main2._get_path(config, section, option, is_dir=False, dpath=<object object>)
Get and validate the path specified in the configuration file.
- Parameters
config (configparser.ConfigParser) – The initialized config parser.
section (string) – The section of the configuration file.
option (string) – The option of the section.
is_dir (bool) – Whether the path is a directory.
dpath (string) – The default path.
- Returns
path – The verified absolute path specified in the option.
- Return type
string
- mrt.main2._get_ctx(config, section, dctx=cpu(0))
Get the context specified in the configuration file.
- Parameters
config (configparser.ConfigParser) – The initialized config parser.
section (string) – The section of the configuration file.
dctx (mxnet.context) – The default context.
- Returns
ctx – The context specified in the section.
- Return type
mxnet.context
- mrt.main2.ARRAY(dtype)
Array wrapper of the uniform data type.
- Parameters
dtype (string) – The data type to be uniformly wrapped into an array.
- Returns
ret – The wrapped data type name.
- Return type
string
- mrt.main2.PAIR(*dtypes)
Multi-level map wrapper of the uniform data types.
- Parameters
dtypes (list of string) – The data types to be uniformly wrapped into a multi-level map.
- Returns
ret – The wrapped data type name.
- Return type
string
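Both wrappers return a type name as a string, which the parser can later inspect. The sketch below shows one way such names could be composed. `array_t` and `pair_t` are hypothetical, and the naming scheme is an assumption; only the `'_str_'` name appears in the signatures documented here.

```python
STR_T = "_str_"  # matches the default dtype shown in the signatures
INT_T = "_int_"  # assumed by analogy, for illustration only


def array_t(dtype):
    """Compose the name for 'array of dtype'."""
    return "_array" + dtype


def pair_t(*dtypes):
    """Compose the name for a multi-level map with one dtype per level."""
    assert len(dtypes) > 0, "at least one dtype is required"
    return "_pair" + "".join(dtypes)


print(array_t(INT_T))        # _array_int_
print(pair_t(STR_T, STR_T))  # _pair_str__str_
```

Encoding the structure in the name keeps the `dtype` argument a plain string, as in the `dtype=PAIR(str_t, str_t)` call of the merge stage.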
- mrt.main2._get_val(config, section, option, dtype='_str_', dval=<object object>)
Get the value of the option in the section, with data type and default value specified.
- Parameters
config (configparser.ConfigParser) – The initialized config parser.
section (string) – The section of the configuration file.
option (string) – The option of the section.
dtype (string) – The data type to be recognised by the parser.
dval (string, int, etc) – The default value.
- Returns
val – The parsed value.
- Return type
string, int, etc
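The `<object object>` default in the signature suggests a sentinel object, which lets `None` be passed as a legitimate default value. The sketch below illustrates that pattern together with dtype-driven casting. `get_val`, the caster table, the array syntax, and the section and option names are assumptions for illustration and not the library's implementation.

```python
import configparser

_NO_DEFAULT = object()  # sentinel, so that None can be a legitimate default

# Assumed dtype names and casters, for illustration only.
_CASTERS = {
    "_str_": str,
    "_int_": int,
    "_float_": float,
    "_bool_": lambda s: s.strip().lower() in ("true", "1", "yes"),
}


def get_val(config, section, option, dtype="_str_", dval=_NO_DEFAULT):
    """Read an option and cast it according to the dtype name."""
    if not config.has_option(section, option):
        if dval is _NO_DEFAULT:
            raise KeyError("missing option %s in section %s"
                           % (option, section))
        return dval
    raw = config.get(section, option)
    if dtype.startswith("_array"):
        # "[1, 2, 3]" becomes a list of the element type.
        cast = _CASTERS[dtype[len("_array"):]]
        items = raw.strip().strip("[]").split(",")
        return [cast(item.strip()) for item in items if item.strip()]
    return _CASTERS[dtype](raw)


config = configparser.ConfigParser()
config.read_string(
    "[CALIBRATION]\n"
    "Calibrate_num = 16\n"
    "Input_shape = [-1, 3, 224, 224]\n"
)
print(get_val(config, "CALIBRATION", "Calibrate_num", dtype="_int_"))
# 16
print(get_val(config, "CALIBRATION", "Input_shape", dtype="_array_int_"))
# [-1, 3, 224, 224]
print(get_val(config, "CALIBRATION", "Lambda", dtype="_float_", dval=None))
# None
```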
- mrt.main2._cast_val(section, option, val_, dtype='_str_')
Cast the value of the option in the section to the specified data type.
- Parameters
section (string) – The section of the configuration file.
option (string) – The option of the section.
val_ – The value to be cast.
dtype (string) – The data type to be recognised by the parser.
- Returns
val – The cast value.
- Return type
string, int, etc
- mrt.main2._load_fname(prefix, suffix=None, with_ext=False)
Get the names of the model files at a given stage.
- Parameters
prefix (string) – The file path without an extension.
suffix (string) – The file suffix with respect to a given stage of MRT.
with_ext (bool) – Whether to include ext file.
- Returns
files – The loaded file names.
- Return type
tuple of string
- mrt.main2._checkpoint_exist(sec, *flist)
Check whether the given files satisfy the checkpoint of the MRT stage.
- Parameters
sec (string) – The MRT stage to be checked.
flist (list of string) – The checkpoint files to be checked.
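The last two helpers fit together naturally: one builds the file names of a stage, and the other checks that they all exist. The sketch below illustrates this under stated assumptions. The `.json`, `.params`, and `.ext` extensions, the `resnet` prefix, the `mrt.quantize` suffix, and the stage name are hypothetical, and both functions are illustrative reimplementations.

```python
import os
import tempfile


def load_fname(prefix, suffix=None, with_ext=False):
    """Build the file names of one stage from a common prefix."""
    stem = prefix if suffix is None else "%s.%s" % (prefix, suffix)
    names = [stem + ".json", stem + ".params"]  # assumed extensions
    if with_ext:
        names.append(stem + ".ext")
    return tuple(names)


def checkpoint_exist(sec, *flist):
    """Return True only when every checkpoint file of stage `sec` exists."""
    missing = [f for f in flist if not os.path.exists(f)]
    for f in missing:
        print("stage %s: missing checkpoint file %s" % (sec, f))
    return not missing


workdir = tempfile.mkdtemp()
files = load_fname(os.path.join(workdir, "resnet"),
                   suffix="mrt.quantize", with_ext=True)
before = checkpoint_exist("QUANTIZATION", *files)  # nothing saved yet
for name in files:
    open(name, "w").close()  # create empty placeholder files
after = checkpoint_exist("QUANTIZATION", *files)
print(before, after)  # False True
```

A check of this kind allows a pipeline to resume from a stage whose outputs are already on disk instead of recomputing them.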