EvalPrediction does not allow for "sources" parameter which "sari" metric requires #15966


Closed
alisonhc opened this issue Mar 7, 2022 · 16 comments


@alisonhc

alisonhc commented Mar 7, 2022

class EvalPrediction(NamedTuple):

Hello, I have been following this example and would like to use the sari metric, which requires sources in addition to predictions and references.

Would it be possible to modify this to allow passing in source utterances so that the compute_metrics parameter can successfully pass the appropriate information to my custom compute_metrics function? Thanks!
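
For context, the sari metric expects three parallel lists rather than two; a minimal standalone sketch (assuming the datasets library's load_metric, with illustrative sentences):

import datasets

sari = datasets.load_metric("sari")
result = sari.compute(
    sources=["About 95 species are currently accepted."],    # original sentences
    predictions=["About 95 species are currently known."],   # model outputs
    references=[["About 95 species are currently known.",    # one list of references
                 "About 95 species are now accepted."]],     # per source sentence
)
print(result)  # e.g. {'sari': ...}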

@lmvasque
Contributor

lmvasque commented Mar 8, 2022

I've placed a similar request to this one (I would say it's the same :)); I'm not sure if it has been released. @mariosasko, your advice on this?
huggingface/datasets#3818

@alisonhc
Author

alisonhc commented Mar 9, 2022

I saw this! I think it's fixed in the datasets library but isn't fixed yet in this one.

@LysandreJik
Member

Hey! We're happy to review PRs if any of you want to try your hand at contributing!

@alisonhc
Author

I might be interested in contributing, but I would have quite a few questions first.

I looked at the code, and it seems like transformers/src/transformers/trainer.py is one of the main files that call compute_metrics with EvalPrediction. I see the inputs variable in this line. How do I get just the inputs from this variable and not the targets? One of the docstrings says: inputs (Dict[str, Union[torch.Tensor, Any]]): The inputs and targets of the model.

My goal is to have a variable all_inputs just like all_labels and all_preds that contains the source utterances so that it can be used as a parameter for compute_metrics.
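
For reference, in a typical seq2seq batch the sources and the targets live under different keys of that inputs dict; a minimal sketch (assuming a batch produced by something like DataCollatorForSeq2Seq, where "input_ids" are the tokenized source utterances and "labels" are the tokenized targets):

def split_seq2seq_batch(batch):
    """Separate source token ids from target token ids in a collated batch (sketch)."""
    source_ids = batch["input_ids"]  # the inputs (source utterances)
    target_ids = batch["labels"]     # the targets (references)
    return source_ids, target_ids

# Hypothetical usage inside the evaluation loop:
# for step, inputs in enumerate(dataloader):
#     source_ids, target_ids = split_seq2seq_batch(inputs)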

@lmvasque
Contributor

lmvasque commented Mar 15, 2022

I've partially made it work in my code with the following changes, but more work is needed to actually have a production-ready solution, since it depends on a lot of code.

Adding the inputs to trainer.py, from line 2419 (just a summary):

# losses/preds/labels on CPU (final containers)
..
        all_labels = None
..
        for step, inputs in enumerate(dataloader):
..
            inputs = inputs.data['decoder_input_ids']
..

            # Update containers on host
..
            if inputs is not None:
                inputs = self._pad_across_processes(inputs)
                inputs = self._nested_gather(inputs)
                inputs_host = inputs if inputs_host is None else nested_concat(inputs_host, inputs, padding_index=-100)
..

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
..
                if inputs_host is not None:
                    inputs = nested_numpify(inputs_host)
                    all_inputs = inputs if all_inputs is None else nested_concat(all_inputs, inputs, padding_index=-100)
..

        # Gather all remaining tensors and put them back on the CPU
..
        if inputs_host is not None:
            inputs = nested_numpify(inputs_host)
            all_inputs = inputs if all_inputs is None else nested_concat(all_inputs, inputs, padding_index=-100)
..

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
..
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)
..
        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            metrics = self.compute_metrics(EvalPrediction(inputs=all_inputs, predictions=all_preds, label_ids=all_labels))

Then, in trainer_utils.py, from line 67:

class EvalPrediction(NamedTuple):
..
    inputs: Union[np.ndarray, Tuple[np.ndarray]]
..

And more work has to be done in the compute_metrics() function called from trainer.py. For now, I'm using my metric directly in my transformers example file:

    def compute_metrics(eval_preds):
        inputs, preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)

        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)

        # Some simple post-processing
        decoded_inputs, decoded_preds, decoded_labels = postprocess_text(decoded_inputs, decoded_preds,
                                                                         decoded_labels, "sari")
        sari_result = sari._compute(inputs=decoded_inputs, predictions=decoded_preds, references=decoded_labels)

I hope it helps with the refactoring :)

@LysandreJik
Member

cc @sgugger

@sgugger
Collaborator

sgugger commented Mar 28, 2022

I don't really understand the changes you suggest @lmvasque since they use a variable inputs_host that is not defined anywhere in the trainer.py file. It would be easier to study the diff of what you suggest on a PR.

Note that a line such as

inputs = inputs.data['decoder_input_ids']

can't be accepted, since it's super specific to a model (not all models have decoder_input_ids) and also relies on the data field of the batch, which doesn't always exist.

@lmvasque
Contributor

Thanks for reviewing this @sgugger. Yes, I've just realized that this code is executed only when running on GPU. I had the chance to run it in this setting this week and yes, you are right about the changes you mention.

I've added all my changes as a pull request so you can easily review them (please use them as a reference, not as a ready-to-go feature): #16461

About these changes:

  • These are definitely not enough for production; further changes are needed in compute_metrics(), but the dependencies start to get messy.
  • These changes work for me by using my own version of compute_metrics in my external metrics file.
  • To add the inputs, I've replicated the handling of preds and labels across the code. However, I don't know if all of these transformations are necessary; I don't understand this code deeply enough to tell :)

@sgugger
Collaborator

sgugger commented Mar 28, 2022

I had not realized you were adding a field to the EvalPrediction named tuple. That's unfortunately a breaking change we can't do, as it would break the code of every user doing evaluation with a compute_metrics function.
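
To make the concern concrete: EvalPrediction is a NamedTuple, and many existing compute_metrics functions unpack it positionally, so changing its arity breaks them. A minimal sketch of the failure mode:

from typing import NamedTuple, Tuple, Union
import numpy as np

class EvalPrediction(NamedTuple):
    predictions: Union[np.ndarray, Tuple[np.ndarray]]
    label_ids: Union[np.ndarray, Tuple[np.ndarray]]

def compute_metrics(eval_pred):
    preds, labels = eval_pred  # positional unpacking, very common in user code
    ...

# Adding a third field (inputs) to the named tuple makes the unpacking above
# raise "ValueError: too many values to unpack", breaking existing metrics code.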

@lmvasque
Contributor

Can this be supported otherwise? Research in Text Simplification uses the inputs in its main evaluation metric, SARI, so we cannot use the Hugging Face pipeline (Datasets + Transformers) for our models (unless we hack the code for our purposes...).

@sgugger
Collaborator

sgugger commented Mar 28, 2022

I'll look into adding something that would be backward compatible and does the same as your PR, but it might take a bit of time. In the meantime, I'd advise using a subclass of the Trainer with your custom code.
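
One possible stop-gap, besides subclassing, is to close over the raw sources in the compute_metrics function itself; a minimal sketch (the make_compute_metrics helper is hypothetical, and it assumes a single-process run where predictions come back in the same order as the raw eval_sources/eval_references lists):

import numpy as np
from datasets import load_metric

sari = load_metric("sari")

def make_compute_metrics(tokenizer, eval_sources, eval_references):
    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # Replace the -100 padding inserted during gathering before decoding.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        return sari.compute(
            sources=eval_sources,        # raw source sentences
            predictions=decoded_preds,   # decoded model outputs
            references=eval_references,  # list of reference lists per source
        )
    return compute_metrics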

@lmvasque
Contributor

That sounds good, I'm happy to do that in the meantime. Thanks again for this! It would be a good step for the Simplification world :)

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@sgugger
Collaborator

sgugger commented Apr 22, 2022

I think this can be closed now, since your PR was merged @lmvasque

@alisonhc
Author

Thanks everyone! Closing.

@hoangthangta

hoangthangta commented Feb 25, 2023

For everyone using the latest version of Transformers (>= v4.21.0, https://newreleases.io/project/github/huggingface/transformers/release/v4.21.0): simply set include_inputs_for_metrics = True in the training arguments.

training_args = Seq2SeqTrainingArguments(
        include_inputs_for_metrics = True,
        # other arguments here
        # ...
)

Then in the compute_metrics() function, you can use inputs.

def compute_metrics(pred):
    # do something with pred.inputs
    ...

Thanks to everyone above. Cheers!
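
Putting the pieces together, a minimal sketch of a compute_metrics that uses pred.inputs for SARI (tokenizer is assumed to come from the surrounding training script, and each target is wrapped as a single-reference list):

import numpy as np
from datasets import load_metric

sari = load_metric("sari")

def compute_metrics(pred):
    preds, labels, inputs = pred.predictions, pred.label_ids, pred.inputs
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 marks ignored/padded positions; replace it before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
    return sari.compute(
        sources=decoded_inputs,
        predictions=decoded_preds,
        references=[[label] for label in decoded_labels],
    )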
