Creating TensorFlow decision forests from individual trees

Is it possible to build a decision forest with TensorFlow from many individual decision trees? And can individual trees be removed from or added to the forest based on some performance criterion? Suppose several workers each train their own trees and, at the end of the process, send them for aggregation into the decision forest. I plan to use this model in federated learning, where the server holds a decision forest and the devices send their trees for aggregation.

I'm finding it a little difficult to understand whether this solution is feasible.

TF-DF developer here.

It's possible (but possibly quite slow) to do this in TF-DF, though this is a bit outside "normal" use cases. The key tools you need are the "model inspector" (to extract trees from an existing model) and the "model builder" (to add trees to a new model). There is a tutorial for these tools here. I'll try to outline the main steps.

A TF-DF tree is represented by the tfdf.py_tree.tree.Tree class. From an existing model, you can extract a tree like this:

tree_idx = 4  # Which tree to extract
inspector = source_model.make_inspector()  # source_model is a TF-DF model
t = inspector.extract_tree(tree_idx=tree_idx)

Now we build a model using the tree:

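import tensorflow_decision_forests as tfdf
import tf_keras

# Create a Random Forest model on disk and add the extracted tree to it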
builder = tfdf.builder.RandomForestBuilder(
    path="/tmp/manual_model",
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="color", classes=["red", "blue", "green"]))
builder.add_tree(t)
# ... possibly add more trees
builder.close()

# Now load the model
manual_model = tf_keras.models.load_model("/tmp/manual_model")
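
For the federated setup in the question, the same two calls can be wrapped in a loop over the workers' models. This is only a sketch: the function name merge_worker_trees is made up for illustration, and it assumes each inspector offers num_trees() and extract_tree(tree_idx=...) and the builder offers add_tree(...), as in the snippets above. It also assumes all workers trained on the same input features and label classes, otherwise the trees will not fit the builder's objective.

```python
def merge_worker_trees(inspectors, builder):
    """Copy every tree from each worker's inspector into one builder.

    `inspectors` is an iterable of model inspectors (one per worker model);
    `builder` is a model builder that has not been closed yet.
    Returns the number of trees added.
    """
    added = 0
    for inspector in inspectors:
        for tree_idx in range(inspector.num_trees()):
            # Same calls as in the single-tree example, applied to every tree.
            builder.add_tree(inspector.extract_tree(tree_idx=tree_idx))
            added += 1
    return added
```

You would call it as merge_worker_trees([m.make_inspector() for m in worker_models], builder) before builder.close(), where worker_models is a hypothetical list of the TF-DF models received from the devices.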

You can also use TF-DF's sister library YDF, which handles this task much faster and more easily. The main advantage is that YDF lets you modify the model directly in memory, while TF-DF only modifies the model files on disk. The YDF model can be exported to TF-DF, so in many cases this will be the better choice. Documentation for modifying models is here.

In YDF, the extraction step above becomes:

import ydf  # Import the library
tree_idx = 4  # Which tree to extract
tree = source_model.get_tree(tree_idx)  # source_model is a YDF model here

Now we build a model using the tree. Note that here, the dataset is only used for defining the input features, not for training (since the model is created with 0 trees).

manual_model = ydf.RandomForestLearner(
    label="label", num_trees=0, task=ydf.Task.REGRESSION).train(dataset)
manual_model.add_tree(tree)  # Add the tree.
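
On the "remove trees based on some performance criteria" part of the question, here is a rough sketch of how pruning could look. It is duck-typed and not an official API: prune_trees, score_tree, and min_score are hypothetical names, and the code assumes the model exposes num_trees(), get_tree(i), and remove_tree(i), which is how I understand YDF's model-editing methods; please check them against the version you have installed. How a single tree is scored (for example, accuracy on a validation set held by the server) is left to the caller.

```python
def prune_trees(model, score_tree, min_score):
    """Remove every tree whose score falls below `min_score`.

    `score_tree` is a caller-supplied function mapping a tree to a number.
    Returns the original indices of the removed trees, highest first.
    """
    removed = []
    # Walk from the last index down so that removing a tree never shifts
    # the index of a tree that has not been visited yet.
    for tree_idx in reversed(range(model.num_trees())):
        if score_tree(model.get_tree(tree_idx)) < min_score:
            model.remove_tree(tree_idx)
            removed.append(tree_idx)
    return removed
```

Combined with add_tree, this gives the server a simple loop: add the trees sent by the devices, score them, and drop the ones that do not meet the threshold.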