🌲 MetaTree 🌲

Mar 21, 2024 | Educational

Learning a Decision Tree Algorithm with Transformers (Zhuang et al. 2024)

MetaTree is a transformer model that generates decision trees. It is trained on the outputs of classical decision tree algorithms (the greedy algorithm CART and the optimal algorithm GOSDT), aiming for better generalization than either.

Quickstart: Using MetaTree to Generate Decision Tree Models

Follow these steps to get started with MetaTree:

  • Model Availability: The pretrained MetaTree model is available on Hugging Face as yzhuang/MetaTree.
  • Install MetaTree:
    pip install metatreelib

    Alternatively, clone the repository and install it in editable mode:

    git clone https://github.com/EvanZhuang/MetaTree
    cd MetaTree
    pip install -e .
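
The quickstart below also relies on imodels, scikit-learn, numpy, torch, and transformers. If pip does not pull them in automatically, you can install them yourself (a minimal sketch; pin versions as needed for your environment):

    pip install imodels scikit-learn numpy torch transformers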

Applying MetaTree to Your Datasets

The following code demonstrates how to use MetaTree to generate a decision tree model:

from metatree.model_metatree import LlamaForMetaTree as MetaTree
from metatree.decision_tree_class import DecisionTree, DecisionTreeForest
from metatree.run_train import preprocess_dimension_patch
from transformers import AutoConfig
from sklearn.metrics import accuracy_score  # used for evaluation below
import imodels
import sklearn.model_selection
import numpy as np
import random
import torch

# Initialize Model
model_name_or_path = "yzhuang/MetaTree"
config = AutoConfig.from_pretrained(model_name_or_path)
model = MetaTree.from_pretrained(model_name_or_path, config=config)
decision_tree_forest = DecisionTreeForest()  # collects the generated trees

# Load the FICO dataset via imodels
X, y, feature_names = imodels.get_clean_dataset("fico", data_source="imodels")
print("Dataset Shapes: X={}, y={}, Num of Classes={}".format(X.shape, y.shape, len(set(y))))

train_idx, test_idx = sklearn.model_selection.train_test_split(range(X.shape[0]), test_size=0.3, random_state=42)
# Dimension subsampling: keep 10 features to match n_feature=10 below
feature_idx = np.random.choice(X.shape[1], 10, replace=False)
X = X[:, feature_idx]
test_X, test_y = X[test_idx], y[test_idx]

# Subsample 256 training examples
subset_idx = random.sample(train_idx, 256)
train_X, train_y = X[subset_idx], y[subset_idx]
input_x = torch.tensor(train_X, dtype=torch.float32)
input_y = torch.nn.functional.one_hot(torch.tensor(train_y)).float()  # labels as one-hot vectors
batch = {'input_x': input_x, 'input_y': input_y, 'input_y_clean': input_y}

# Patch the batch to the fixed feature/class dimensions (10 features, 10 classes)
batch = preprocess_dimension_patch(batch, n_feature=10, n_class=10)
model.depth = 2  # generate a depth-2 tree
outputs = model.generate_decision_tree(batch['input_x'], batch['input_y'], depth=model.depth)

# Wrap the generated splits in a DecisionTree and add it to the forest
decision_tree_forest.add_tree(DecisionTree(
    auto_dims=outputs.metatree_dimensions,
    auto_thresholds=outputs.tentative_splits,
    input_x=batch['input_x'],
    input_y=batch['input_y'],
    depth=model.depth
))
print("Decision Tree Features: ", [x.argmax(dim=-1) for x in outputs.metatree_dimensions])
print("Decision Tree Thresholds: ", outputs.tentative_splits)

# Inference with the decision tree model
tree_pred = decision_tree_forest.predict(torch.tensor(test_X, dtype=torch.float32))
accuracy = accuracy_score(test_y, tree_pred.argmax(dim=-1).squeeze(0))
print("MetaTree Test Accuracy: ", accuracy)

Understanding the Code: An Analogy

Imagine you are a chef preparing a gourmet meal. You need to gather your ingredients, combine them in just the right proportions, and then cook them to perfection.

  • Ingredient Selection (Model Initialization): Just as a chef chooses their equipment, the code loads the pretrained transformer and sets up an empty decision tree forest.
  • Gathering Ingredients (Loading Datasets): Here, you bring in the dataset as your ingredients, making sure it is clean and ready for use.
  • Mixing and Matching (Subsampling and Preparation): The code subsamples features and training examples, much as a chef decides which ingredients to combine for a balanced dish.
  • Cooking (Generating the Tree): Finally, the pretrained model generates the decision tree, where all the components come together into the final product.

Example Usage

For a complete example of using MetaTree, check out the example notebook.

Troubleshooting

If you encounter any issues while working with MetaTree, here are some troubleshooting tips:

  • Ensure that you have installed all necessary libraries, especially imodels and transformers.
  • Double-check your dataset paths and formats—they must align with what is expected by the model.
  • If you experience memory issues while generating trees, consider reducing the size of your subsample or using a machine with more resources (see the sketch after this list).
  • For more insights, updates, or to collaborate on AI development projects, stay connected with fxis.ai.
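
As a minimal sketch of the memory tip above: tree generation is inference, not training, so wrapping the call in torch.no_grad() keeps PyTorch from storing activations for backpropagation. Combined with a smaller subsample, this noticeably lowers the memory footprint:

# Generate the tree without gradient tracking to save memory
with torch.no_grad():
    outputs = model.generate_decision_tree(batch['input_x'], batch['input_y'], depth=model.depth)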

Conclusion

At fxis.ai, we believe that such advancements are crucial for the future of AI, as they enable more comprehensive and effective solutions. Our team is continually exploring new methodologies to push the envelope in artificial intelligence, ensuring that our clients benefit from the latest technological innovations.

Stay Informed with the Newest F(x) Insights and Blogs

Tech News and Blog Highlights, Straight to Your Inbox