I am currently working on a project where I would like to fine-tune a Whisper model via the Hugging Face Transformers library. So far, the fine-tuning has worked well; however, I am stuck on how to choose the hyperparameters for it.
I have tried using Ray Tune (following the guide on the Hugging Face website), which seems to work well for tuning models like BERT. However, it does not work when used with the Transformers Seq2SeqTrainer; it keeps raising the error below.
Does anyone know of any way to integrate Ray Tune, or any other hyperparameter optimisation library, with the Transformers Seq2SeqTrainer? Thank you very much for the help!
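For reference, here is roughly how the trainer and the search are set up (a minimal sketch: the checkpoint name and training arguments are placeholders, and train_dataset, eval_dataset, data_collator, and compute_metrics stand in for objects defined earlier in my notebook; the model_init function and the hyperparameter_search call match what appears in the traceback below):

from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    WhisperForConditionalGeneration,
)

# hyperparameter_search requires model_init so each trial starts from a fresh model
def model_init(trial=None):
    # placeholder checkpoint; I am actually using a different Whisper size
    return WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-finetuned",   # placeholder
    evaluation_strategy="steps",
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model_init=model_init,
    train_dataset=train_dataset,        # placeholder: prepared earlier
    eval_dataset=eval_dataset,          # placeholder: prepared earlier
    data_collator=data_collator,        # placeholder: prepared earlier
    compute_metrics=compute_metrics,    # placeholder: prepared earlier
)

best_run = trainer.hyperparameter_search(
    direction="minimise",
    backend="ray",
    n_trials=10,  # number of trials
)

The hyperparameter_search call is what produces the full traceback: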
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py in put_object(self, value, object_ref, owner_address)
752 try:
--> 753 serialized_value = self.get_serialization_context().serialize(value)
754 except TypeError as e:
16 frames
/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py in serialize(self, value)
493 else:
--> 494 return self._serialize_to_msgpack(value)
/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py in _serialize_to_msgpack(self, value)
471 metadata = ray_constants.OBJECT_METADATA_TYPE_PYTHON
--> 472 pickle5_serialized_object = self._serialize_to_pickle5(
473 metadata, python_objects
/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py in _serialize_to_pickle5(self, metadata, value)
424 self.get_and_clear_contained_object_refs()
--> 425 raise e
426 finally:
/usr/local/lib/python3.10/dist-packages/ray/_private/serialization.py in _serialize_to_pickle5(self, metadata, value)
419 self.set_in_band_serialization()
--> 420 inband = pickle.dumps(
421 value, protocol=5, buffer_callback=writer.buffer_callback
/usr/local/lib/python3.10/dist-packages/ray/cloudpickle/cloudpickle_fast.py in dumps(obj, protocol, buffer_callback)
87 cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback)
---> 88 cp.dump(obj)
89 return file.getvalue()
/usr/local/lib/python3.10/dist-packages/ray/cloudpickle/cloudpickle_fast.py in dump(self, obj)
732 try:
--> 733 return Pickler.dump(self, obj)
734 except RuntimeError as e:
TypeError: cannot pickle 'torch._C.Generator' object
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
<ipython-input-26-25e54ca2b4c4> in <cell line: 1>()
----> 1 trainer.hyperparameter_search(
2 direction="minimise",
3 backend="ray",
4 n_trials=10 # number of trials
5 )
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in hyperparameter_search(self, hp_space, compute_objective, n_trials, direction, backend, hp_name, **kwargs)
2798 self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
2799
-> 2800 best_run = backend_obj.run(self, n_trials, direction, **kwargs)
2801
2802 self.hp_search_backend = None
/usr/local/lib/python3.10/dist-packages/transformers/hyperparameter_search.py in run(self, trainer, n_trials, direction, **kwargs)
85
86 def run(self, trainer, n_trials: int, direction: str, **kwargs):
---> 87 return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
88
89 def default_hp_space(self, trial):
/usr/local/lib/python3.10/dist-packages/transformers/integrations/integration_utils.py in run_hp_search_ray(trainer, n_trials, direction, **kwargs)
330 )
331
--> 332 trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)
333
334 @functools.wraps(trainable)
/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/util.py in with_parameters(trainable, **kwargs)
107 prefix = f"{str(trainable)}_"
108 for k, v in kwargs.items():
--> 109 parameter_registry.put(prefix + k, v)
110
111 trainable_name = getattr(trainable, "__name__", "tune_with_parameters")
/usr/local/lib/python3.10/dist-packages/ray/tune/registry.py in put(self, k, v)
294 self.to_flush[k] = v
295 if ray.is_initialized():
--> 296 self.flush()
297
298 def get(self, k):
/usr/local/lib/python3.10/dist-packages/ray/tune/registry.py in flush(self)
306 self.references[k] = v
307 else:
--> 308 self.references[k] = ray.put(v)
309 self.to_flush.clear()
/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py in auto_init_wrapper(*args, **kwargs)
20 def auto_init_wrapper(*args, **kwargs):
21 auto_init_ray()
---> 22 return fn(*args, **kwargs)
23
24 return auto_init_wrapper
/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py in wrapper(*args, **kwargs)
101 if func.__name__ != "init" or is_client_mode_enabled_by_default:
102 return getattr(ray, func.__name__)(*args, **kwargs)
--> 103 return func(*args, **kwargs)
104
105 return wrapper
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py in put(value, _owner)
2695 with profiling.profile("ray.put"):
2696 try:
-> 2697 object_ref = worker.put_object(value, owner_address=serialize_owner_address)
2698 except ObjectStoreFullError:
2699 logger.info(
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py in put_object(self, value, object_ref, owner_address)
760 f"{sio.getvalue()}"
761 )
--> 762 raise TypeError(msg) from e
763 # This *must* be the first place that we construct this python
764 # ObjectRef because an entry with 0 local references is created when
TypeError: Could not serialize the put value <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>:
================================================================================
Checking Serializability of <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>
================================================================================
!!! FAIL serialization: cannot pickle 'torch._C.Generator' object
Serializing 'compute_metrics' <function compute_metrics at 0x7f9809cd2a70>...
Serializing 'compute_objective' <function default_compute_objective at 0x7f9818ea2b00>...
Serializing 'get_optimizer_cls_and_kwargs' <function Trainer.get_optimizer_cls_and_kwargs at 0x7f98137eedd0>...
Serializing 'load_generation_config' <function Seq2SeqTrainer.load_generation_config at 0x7f98137fcdc0>...
Serializing 'model_init' <function model_init at 0x7f9809c9f0a0>...
Serializing '_activate_neftune' <bound method Trainer._activate_neftune of <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>>...
!!! FAIL serialization: cannot pickle 'torch._C.Generator' object
Serializing '__func__' <function Trainer._activate_neftune at 0x7f98137ee440>...
WARNING: Did not find non-serializable object in <bound method Trainer._activate_neftune of <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>>. This may be an oversight.
================================================================================
Variable:
FailTuple(_activate_neftune [obj=<bound method Trainer._activate_neftune of <transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>>, parent=<transformers.trainer_seq2seq.Seq2SeqTrainer object at 0x7f9809c1f790>])
was found to be non-serializable. There may be multiple other undetected variables that were non-serializable.
Consider either removing the instantiation/imports of these variables or moving the instantiation into the scope of the function/class.
================================================================================
Check https://docs.ray.io/en/master/ray-core/objects/serialization.html#troubleshooting for more information.
If you have any suggestions on how to improve this error message, please reach out to the Ray developers on github.com/ray-project/ray/issues/
================================================================================