-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Python script to convert nnet2 to nnet3 models #1611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
"Uses numpy to combine matrices and vectors into a single matrix for initialisation. Possible to do using just file manipulation, just a bit more messy." I think this dependence on |
KNOWN_COMPONENTS = NODE_NAMES.keys() | ||
# End configuration section | ||
|
||
logger = logging.getLogger(__name__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not familiar with the latest logging guidelines. @vimalmanohar could you take a look.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logging is fine. Although, the guidelines for variable names and spacing in the Google style guide are not followed. https://google.github.io/styleguide/pyguide.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here are some comments after a quick initial review.
if (result != 0): | ||
raise OSError('Encountered an error writing the model.') | ||
|
||
def ParseNnet2(line_buffer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better to rename this as ParseNnet2ToNnet3
|
||
def ConsumeToken(token, line): | ||
'''Returns line without token''' | ||
if token != line.split()[0]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC these lines can be very large. So it might be better to just check for this using regexes or at least using split() method using something like a maxsplit=1 option.
|
||
def WriteModel(self, model, binary="true"): | ||
result = 0 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check that self.config is a proper nnet3 config file.
self.priors, | ||
model), shell=True) | ||
|
||
if (result != 0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think it might be worth adding a check in a top level shell script which performs forward prop through the nnet2 and nnet3 models and asserts that the values are within acceptable threshold ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if there is a easy way to check the correctness of the transition model without doing a decode.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea. How about another argument to the argparser for some features, and if present, will run a forward pass and compare the results using difflib
; so keeping it all in the single Python script?
The transition model isn't read by Numpy, so shouldn't have any numerical errors I think.
result = 0 | ||
|
||
# write raw model | ||
result += subprocess.call('nnet3-init --binary=true {0} {1}'.format(self.config, os.path.join(tmpdir, 'nnet3.raw')), shell=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to consider using the KaldiCommand function available in the nnet3 python libraries ? This might help you handle some common errors.
for component in self.components: | ||
if component.ident == "splice": | ||
# Create splice string for the next node | ||
previous_component=MakeSpliceString(previous_component, component.pairs['<Context>']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not clear how you are handling the SpliceMaxComponent
here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I had forgot about SpliceMaxComponent - it's not explicitly handled at the minute. What is the equivalent in nnet3? I couldn't find a max-descriptor (e.g. 'Max(Offset, ...)')
@@ -0,0 +1,441 @@ | |||
#!/usr/bin/env python | |||
# Copyright 2016 | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add your name to the author list.
Function/method names, spaces, column width.
This is not something that can be done using a descriptor. You would need
to add a component. I think Max pooling component can be used but I don't
remember if the data needs to be rearranged. However I know of no
architectures which were using splicemaxcomponent in nnet2, so I think you
can just raise an error rather than handling this case.
…On May 9, 2017 3:02 AM, "Joachim" ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In egs/wsj/s5/steps/nnet3/convert_nnet2_to_nnet3.py
<#1611 (comment)>:
> + self.config = filename
+ with open(filename, 'w') as f:
+ for component in self.components:
+ if component.ident == "splice":
+ continue
+ config_string = ' '.join(component.pairs)
+
+ f.write('component name={name} type={comp_type} {config_string}\n'.format(name=component.ident, comp_type=component.component, config_string=config_string))
+
+ f.write('\n# Component nodes\n')
+ f.write('input-node name=input dim={0}\n'.format(self.input_dim))
+ previous_component='input'
+ for component in self.components:
+ if component.ident == "splice":
+ # Create splice string for the next node
+ previous_component=MakeSpliceString(previous_component, component.pairs['<Context>'])
Sorry, I had forgot about SpliceMaxComponent - it's not explicitly handled
at the minute. What is the equivalent in nnet3? I couldn't find a
max-descriptor (e.g. 'Max(Offset, ...)')
—
You are receiving this because you commented.
Reply to this email directly, view it on GitHub
<#1611 (comment)>, or mute
the thread
<https://github.com./notifications/unsubscribe-auth/ADtwoKMhBNXJMC7hAigNiFpdMICSK5riks5r4DmbgaJpZM4NTGyg>
.
|
So @jfainberg, this is ready to commit? Still says WIP. |
@vijayaditya if you say it's still good I'll merge. |
Yes @danpovey, please go ahead. I've tested it (forward pass) with typical nnet2 pnorm and tanh networks. |
@danpovey Went through it briefly. LGTM. |
* 'master' of https://github.com./kaldi-asr/kaldi: (140 commits) [egs] Fix failure in multilingual BABEL recipe (regenerate cmvn.scp) (kaldi-asr#1686) [src,scripts,egs] Backstitch code+scripts, and one experiment, will add more later. (kaldi-asr#1605) [egs] CNN+TDNN+LSTM experiments on AMI (kaldi-asr#1685) [egs,scripts,src] Tune image recognition examples; minor small changes. (kaldi-asr#1682) [src] Fix bug in looped computation (kaldi-asr#1673) [build] when installing sequitur and mmseg, look for lib64 as well (thanks: @akshayc11) (kaldi-asr#1677) [src] fix to gst-plugin/Makefile (remove -lkaldi-thread) (kaldi-asr#1680) [src] Cosmetic fixes to usage messages [egs] Fix to some --proportional-shrink related example scripts (kaldi-asr#1674) [build] Fix small bug in configure [scripts] Fix small bug in utils/gen_topo.pl. [scripts] Add python script to convert nnet2 to nnet3 models (kaldi-asr#1611) [doc] Fix typo (kaldi-asr#1669) [src] nnet3: fix small bug in checking code. Thanks: @Maddin2000. [src] Add #include missing from previous commit [src] Fix bug in online2-nnet3 decoding RE dropout+batch-norm (thanks: Wonkyum Lee) [scripts] make errors getting report non-fatal (thx: Miguel Jette); add comment RE dropout proportion [src,scripts] Use ConstFst or decoding (half the memory; slightly faster). (kaldi-asr#1661) [src] keyword search tools: fix Minimize() call, necessary due to OpenFst upgrade (kaldi-asr#1663) [scripts] do not fail if the ivector extractor belongs to different user (kaldi-asr#1662) ...
Re #886 .
Initial implementation - I'll test it a bit more over the next few days, but open for comments.
Couple of points:
numpy
to combine matrices and vectors into a single matrix for initialisation. Possible to do using just file manipulation, just a bit more messy.