{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [
 "# How to add new CondDensityEstimators, CauseClusterers, and EffectClusterers to CFL\n",
 "While the CFL software package comes with pre-implemented (either by us or \n",
 "Scikit-learn) models for conditional density estimation and clustering, it has\n",
 "also been designed to make it easy to try out new models. To do so, there are\n",
 "two main steps:\n",
 "- Make a Python class for your model that inherits from the appropriate\n",
 "  abstract class (either CDEModel, CCModel, or ECModel)\n",
 "- Pass an instance of it in to your Experiment" ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "from cfl import Experiment\n",
 "from cfl.cond_density_estimation import CDEModel" ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10000, 5)\n", "(10000, 3)\n" ] } ], "source": [
 "# generate toy data\n",
 "# seed NumPy's global generator first so the same X and Y are drawn on every\n",
 "# fresh run of the notebook\n",
 "np.random.seed(0)\n",
 "data_info = {'X_dims' : (10000, 5),\n",
 "             'Y_dims' : (10000, 3),\n",
 "             'Y_type' : 'continuous'}\n",
 "X = np.random.normal(size=data_info['X_dims'])\n",
 "Y = np.random.normal(size=data_info['Y_dims'])\n",
 "print(X.shape)\n",
 "print(Y.shape)" ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
 "# make a new conditional density estimator that inherits CDEModel. 
Your class\n",
 "# must implement all methods specified by CDEModel\n",
 "class MyCDE(CDEModel):\n",
 "    \"\"\"Toy estimator: its 'pyx' result is standard-normal noise shaped like Y.\"\"\"\n",
 "\n",
 "    def __init__(self, data_info, model_params):\n",
 "        \"\"\"Keep data_info and model_params on the instance.\"\"\"\n",
 "        self.data_info, self.model_params = data_info, model_params\n",
 "\n",
 "    def _noise_like_Y(self, dataset):\n",
 "        \"\"\"Draw standard-normal samples with the shape of dataset.get_Y().\"\"\"\n",
 "        y_shape = dataset.get_Y().shape\n",
 "        return np.random.normal(size=y_shape)\n",
 "\n",
 "    def train(self, dataset, prev_results=None):\n",
 "        \"\"\"Return a dict holding freshly drawn noise under the 'pyx' key.\"\"\"\n",
 "        return dict(pyx=self._noise_like_Y(dataset))\n",
 "\n",
 "    def predict(self, dataset, prev_results=None):\n",
 "        \"\"\"Return a dict holding freshly drawn noise under the 'pyx' key.\"\"\"\n",
 "        return dict(pyx=self._noise_like_Y(dataset))\n",
 "\n",
 "    def load_model(self, path):\n",
 "        \"\"\"Placeholder: nothing is read from path.\"\"\"\n",
 "        return None\n",
 "\n",
 "    def save_model(self, path):\n",
 "        \"\"\"Placeholder: nothing is written to path.\"\"\"\n",
 "        return None\n",
 "\n",
 "    def get_model_params(self):\n",
 "        \"\"\"Return the model_params given at construction.\"\"\"\n",
 "        return self.model_params" ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "Block: model_params not specified in input, defaulting to {}\n",
 "Block: verbose not specified in input, defaulting to 1\n",
 "Block: tune not specified in input, defaulting to False\n",
 "Block: user_input not specified in input, defaulting to True\n",
 "Block: verbose not specified in input, defaulting to 1\n" ] } ], "source": [
 "# An instance of MyCDE may be given as the 'model' entry of CDE_params,\n",
 "# where the string name of a pre-defined model would otherwise go\n",
 "CDE_params = dict(model=MyCDE(data_info, model_params={}))\n",
 "CC_params = dict(model='KMeans', model_params=dict(n_clusters=2))\n",
 "\n",
 "# one parameter dict per block, in the same order as the block names\n",
 "block_names = ['CondDensityEstimator', 'CauseClusterer']\n",
 "block_params = [CDE_params, CC_params]\n",
 "my_exp = Experiment(\n",
 "    X_train=X,\n",
 "    Y_train=Y,\n",
 "    data_info=data_info,\n",
 "    block_names=block_names,\n",
 "    block_params=block_params,\n",
 "    results_path=None,\n",
 ")" ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "#################### Beginning CFL Experiment training. 
####################\n", "Beginning CondDensityEstimator training...\n", "CondDensityEstimator training complete.\n", "Beginning CauseClusterer training...\n", "CauseClusterer training complete.\n", "Experiment training complete.\n" ] }, { "data": { "text/plain": [ "{'CondDensityEstimator': {'pyx': array([[ 1.61736868, -0.83274787, 1.79892445],\n", " [ 0.93879797, 0.55829648, -1.65463966],\n", " [-0.53689253, -0.332663 , 1.45926297],\n", " ...,\n", " [ 0.39492332, -0.39107461, -0.41794009],\n", " [-0.30334537, -1.12357842, 0.44379039],\n", " [-1.05910614, -0.58523058, 0.2388775 ]])},\n", " 'CauseClusterer': {'x_lbls': array([0, 0, 1, ..., 0, 1, 1], dtype=int32)}}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_exp.train()" ] } ], "metadata": { "interpreter": { "hash": "a3e2eae926291f18d1afe7b431ab5303b6e9ec8df44d60cf0ff75f0c781aceac" }, "kernelspec": { "display_name": "Python 3.7.9 ('cfl_env')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }