TensorFlow random forest plotting errors


I'm running Jupyter on Anaconda on a Mac (M2).

After fitting the model on the training data:

import tensorflow_decision_forests as tfdf

rf = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION)
rf.compile(metrics=["mse"])
rf.fit(x=train_ds)  # train_ds: the training tf.data.Dataset (not shown)

I want to visualize the model with the following code, but nothing is displayed:

tfdf.model_plotter.plot_model_in_colab(rf, tree_idx=0, max_depth=3)

Can anyone suggest what I should do?

Yes, I tried ChatGPT: it gave me the same code several times, or variations of it, and still nothing is displayed. According to ChatGPT, I have all the dependencies installed.

Best answer:

TF-DF author here.

Unfortunately, interactive plotting with TF-DF only works in Colab, not in IPython/Jupyter, because the two have slightly different JavaScript integrations. Currently, you have three options:

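  1. Use non-interactive text plots. The example below prints tree #1 of a classifier trained on the penguins dataset (`model_1`); for the model in the question, use `rf` instead: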
> print(model_1.make_inspector().extract_tree(1))
(bill_depth_mm >= 16.350000381469727; miss=True, score=0.4877108931541443)
    ├─(pos)─ (bill_length_mm >= 43.05000305175781; miss=True, score=0.4372641444206238)
    │        ├─(pos)─ (body_mass_g >= 4125.0; miss=True, score=0.52157062292099)
    │        │        ├─(pos)─ (flipper_length_mm >= 199.01458740234375; miss=True, score=0.5047621130943298)
    │        │        │    ...
    │        │        └─(neg)─ ProbabilityValue([0.0, 0.0, 1.0],n=38.0) (idx=5)
    │        └─(neg)─ (bill_depth_mm >= 17.450000762939453; miss=False, score=0.015847451984882355)
    │                 ├─(pos)─ ProbabilityValue([1.0, 0.0, 0.0],n=68.0) (idx=4)
    │                 └─(neg)─ (bill_length_mm >= 38.900001525878906; miss=True, score=0.0711795762181282)
    │                      ...
    └─(neg)─ (body_mass_g >= 3750.0; miss=True, score=0.20150887966156006)
             ├─(pos)─ ProbabilityValue([0.0, 1.0, 0.0],n=93.0) (idx=1)
             └─(neg)─ ProbabilityValue([1.0, 0.0, 0.0],n=5.0) (idx=0)
  2. If you want rich visualizations with many options and more information, you can use dtreeviz. The TensorFlow website has a tutorial that explains in detail how to use it with TF-DF.

  3. Extract the HTML that TF-DF produces yourself and open it in a compatible viewer, such as a web browser:

html = tfdf.model_plotter.plot_model(rf, tree_idx=0, max_depth=3)
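For option 3, here is a minimal sketch that writes the HTML string returned by `plot_model` to a file and opens it in your default browser using Python's standard `webbrowser` module. The helper `save_tree_html` and the file names are hypothetical and not part of TF-DF. If the generated HTML loads its JavaScript libraries from a CDN, the browser may also need internet access.

```python
import pathlib
import webbrowser


def save_tree_html(html: str, path: str = "tree.html", open_browser: bool = True) -> pathlib.Path:
    """Hypothetical helper (not part of TF-DF).

    Writes an HTML string, such as the one returned by
    tfdf.model_plotter.plot_model, to `path` and optionally opens it in the
    default web browser, where the embedded JavaScript can run outside Jupyter.
    """
    out = pathlib.Path(path).resolve()  # absolute path, needed for as_uri()
    out.write_text(html, encoding="utf-8")
    if open_browser:
        # webbrowser.open expects a URL, so convert the path to a file:// URI.
        webbrowser.open(out.as_uri())
    return out


# Usage with the model from the question (requires tensorflow_decision_forests):
# html = tfdf.model_plotter.plot_model(rf, tree_idx=0, max_depth=3)
# save_tree_html(html, "rf_tree_0.html")
```

Writing to a file avoids depending on the notebook's JavaScript integration entirely; the trade-off is that the plot opens in a separate browser tab instead of inline.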