How to optimise hyperparameters for Whisper fine-tuning?


I am currently working on a project where I would like to fine-tune a Whisper model via the Hugging Face Transformers library. The fine-tuning itself has worked well; however, I am stuck on how to choose the hyperparameters for it.

I have tried using Ray Tune (by following the guide on the Hugging Face website), which seems to work well for tuning models like BERT. However, it does not seem to work with the Transformers Seq2SeqTrainer, as it keeps giving me the error below.

Does anyone know of a way to integrate Ray Tune, or any other hyperparameter optimisation library, with Transformers' Seq2SeqTrainer? Thank you very much for the help!
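From the traceback, my understanding is that `ray.tune.with_parameters` tries to pickle the entire `Seq2SeqTrainer`, and the trainer holds a `torch._C.Generator` that cannot be pickled. A minimal stdlib analogue of that failure mode (the `FakeTrainer` class and the `threading.Lock` attribute are purely illustrative stand-ins for the trainer and the generator):

```python
import pickle
import threading

class FakeTrainer:
    """Stand-in for Seq2SeqTrainer: holds an unpicklable handle
    (a threading.Lock here, analogous to torch._C.Generator)."""
    def __init__(self):
        # Any object backed by a C-level handle without __reduce__
        # support will break pickle in the same way.
        self.rng = threading.Lock()

err = None
try:
    pickle.dumps(FakeTrainer())
except TypeError as e:
    err = str(e)

print("TypeError:", err)
```

So the question is really how to hand the trainer (or a factory for it) to the tuning backend without serialising the unpicklable state.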

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py in put_object(self, value, object_ref, owner_address)
    752         try:
--> 753             serialized_value = self.get_serialization_context().serialize(value)
    754         except TypeError as e:

16 frames
/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py in serialize(self, value)
    493         else:
--> 494             return self._serialize_to_msgpack(value)

/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py in _serialize_to_msgpack(self, value)
    471             metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
--> 472             pickle5_serialized_object = self._serialize_to_pickle5(
    473                 metadata, python_objects

/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py in _serialize_to_pickle5(self, metadata, value)
    424             self.get_and_clear_contained_object_refs()
--> 425             raise e
    426         finally:

/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py in _serialize_to_pickle5(self, metadata, value)
    419             self.set_in_band_serialization()
--> 420             inband = pickle.dumps(
    421                 value, protocol=5, buffer_callback=writer.buffer_callback

/usr/local/lib/python3.10/dist-packages/ray/cloudpickle/cloudpickle_fast.py in dumps(obj, protocol, buffer_callback)
     87             cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback)
---> 88             cp.dump(obj)
     89             return file.getvalue()

/usr/local/lib/python3.10/dist-packages/ray/cloudpickle/cloudpickle_fast.py in dump(self, obj)
    732         try:
--> 733             return Pickler.dump(self, obj)
    734         except RuntimeError as e:

TypeError: cannot pickle 'torch._C.Generator' object

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-26-25e54ca2b4c4> in <cell line: 1>()
----> 1 trainer.hyperparameter_search(
      2     direction="minimise",
      3     backend="ray",
      4     n_trials=10 # number of trials
      5 )

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in hyperparameter_search(self, hp_space, compute_objective, n_trials, direction, backend, hp_name, **kwargs)
   2798         self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
   2799 
-> 2800         best_run = backend_obj.run(self, n_trials, direction, **kwargs)
   2801 
   2802         self.hp_search_backend = None

/usr/local/lib/python3.10/dist-packages/transformers/hyperparameter_search.py in run(self, trainer, n_trials, direction, **kwargs)
     85 
     86     def run(self, trainer, n_trials: int, direction: str, **kwargs):
---> 87         return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
     88 
     89     def default_hp_space(self, trial):

/usr/local/lib/python3.10/dist-packages/transformers/integrations/integration_utils.py in run_hp_search_ray(trainer, n_trials, direction, **kwargs)
    330             )
    331 
--> 332     trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)
    333 
    334     @functools.wraps(trainable)

/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/util.py in with_parameters(trainable, **kwargs)
    107     prefix = f"{str(trainable)}_"
    108     for k, v in kwargs.items():
--> 109         parameter_registry.put(prefix + k, v)
    110 
    111     trainable_name = getattr(trainable, "__name__", "tune_with_parameters")

/usr/local/lib/python3.10/dist-packages/ray/tune/registry.py in put(self, k, v)
    294         self.to_flush[k] = v
    295         if ray.is_initialized():
--> 296             self.flush()
    297 
    298     def get(self, k):

/usr/local/lib/python3.10/dist-packages/ray/tune/registry.py in flush(self)
    306                 self.references[k] = v
    307             else:
--> 308                 self.references[k] = ray.put(v)
    309         self.to_flush.clear()

/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py in auto_init_wrapper(*args, **kwargs)
     20     def auto_init_wrapper(*args, **kwargs):
     21         auto_init_ray()
---> 22         return fn(*args, **kwargs)
     23 
     24     return auto_init_wrapper

/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py in wrapper(*args, **kwargs)
    101             if func.__name__ != "init" or is_client_mode_enabled_by_default:
    102                 return getattr(ray, func.__name__)(*args, **kwargs)
--> 103         return func(*args, **kwargs)
    104 
    105     return wrapper

/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py in put(value, _owner)
   2695     with profiling.profile("ray.put"):
   2696         try:
-> 2697             object_ref = worker.put_object(value, owner_address=serialize_owner_address)
   2698         except ObjectStoreFullError:
   2699             logger.info(

/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py in put_object(self, value, object_ref, owner_address)
    760                 f"{sio.getvalue()}"
    761             )
--> 762             raise TypeError(msg) from e
    763         # This *must* be the first place that we construct this python
    764         # ObjectRef because an entry with 0 local references is created when

TypeError: Could not serialize the put value <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>:
================================================================================
Checking Serializability of <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>
================================================================================
!!! FAIL serialization: cannot pickle 'torch._C.Generator' object
    Serializing 'compute_metrics' <function compute_metrics at 0x7f9809cd2a70>...
    Serializing 'compute_objective' <function default_compute_objective at 0x7f9818ea2b00>...
    Serializing 'get_optimizer_cls_and_kwargs' <function Trainer.get_optimizer_cls_and_kwargs at 0x7f98137eedd0>...
    Serializing 'load_generation_config' <function Seq2SeqTrainer.load_generation_config at 0x7f98137fcdc0>...
    Serializing 'model_init' <function model_init at 0x7f9809c9f0a0>...
    Serializing '_activate_neftune' <bound method Trainer._activate_neftune of <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>>...
    !!! FAIL serialization: cannot pickle 'torch._C.Generator' object
        Serializing '__func__' <function Trainer._activate_neftune at 0x7f98137ee440>...
    WARNING: Did not find non-serializable object in <bound method Trainer._activate_neftune of <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>>. This may be an oversight.
================================================================================
Variable: 

    FailTuple(_activate_neftune [obj=<bound method Trainer._activate_neftune of <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>>, parent=<transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>])

was found to be non-serializable. There may be multiple other undetected variables that were non-serializable. 
Consider either removing the instantiation/imports of these variables or moving the instantiation into the scope of the function/class. 
================================================================================
Check https://docs.ray.io/en/master/ray-core/objects/serialization.html#troubleshooting for more information.
If you have any suggestions on how to improve this error message, please reach out to the Ray developers on github.com/ray-project/ray/issues/
================================================================================
