import opendatasets as od
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import os
%matplotlib inline

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 150)
sns.set_style('darkgrid')
matplotlib.rcParams['font.size'] = 14
matplotlib.rcParams['figure.figsize'] = (10, 6)
matplotlib.rcParams['figure.facecolor'] = '#00000000'
os.listdir('weather-dataset-rattle-package')
['weatherAUS.csv']
raw_df = pd.read_csv('weather-dataset-rattle-package/weatherAUS.csv')
raw_df.head(10)
Date Location MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustDir WindGustSpeed WindDir9am WindDir3pm WindSpeed9am WindSpeed3pm Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm RainToday RainTomorrow
0 2008-12-01 Albury 13.4 22.9 0.6 NaN NaN W 44.0 W WNW 20.0 24.0 71.0 22.0 1007.7 1007.1 8.0 NaN 16.9 21.8 No No
1 2008-12-02 Albury 7.4 25.1 0.0 NaN NaN WNW 44.0 NNW WSW 4.0 22.0 44.0 25.0 1010.6 1007.8 NaN NaN 17.2 24.3 No No
2 2008-12-03 Albury 12.9 25.7 0.0 NaN NaN WSW 46.0 W WSW 19.0 26.0 38.0 30.0 1007.6 1008.7 NaN 2.0 21.0 23.2 No No
3 2008-12-04 Albury 9.2 28.0 0.0 NaN NaN NE 24.0 SE E 11.0 9.0 45.0 16.0 1017.6 1012.8 NaN NaN 18.1 26.5 No No
4 2008-12-05 Albury 17.5 32.3 1.0 NaN NaN W 41.0 ENE NW 7.0 20.0 82.0 33.0 1010.8 1006.0 7.0 8.0 17.8 29.7 No No
5 2008-12-06 Albury 14.6 29.7 0.2 NaN NaN WNW 56.0 W W 19.0 24.0 55.0 23.0 1009.2 1005.4 NaN NaN 20.6 28.9 No No
6 2008-12-07 Albury 14.3 25.0 0.0 NaN NaN W 50.0 SW W 20.0 24.0 49.0 19.0 1009.6 1008.2 1.0 NaN 18.1 24.6 No No
7 2008-12-08 Albury 7.7 26.7 0.0 NaN NaN W 35.0 SSE W 6.0 17.0 48.0 19.0 1013.4 1010.1 NaN NaN 16.3 25.5 No No
8 2008-12-09 Albury 9.7 31.9 0.0 NaN NaN NNW 80.0 SE NW 7.0 28.0 42.0 9.0 1008.9 1003.6 NaN NaN 18.3 30.2 No Yes
9 2008-12-10 Albury 13.1 30.1 1.4 NaN NaN W 28.0 S SSE 15.0 11.0 58.0 27.0 1007.0 1005.7 NaN NaN 20.1 28.2 Yes No
raw_df.shape
(145460, 23)
raw_df.info() # to check column types of dataset
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 145460 entries, 0 to 145459
Data columns (total 23 columns):
 #   Column         Non-Null Count   Dtype  
---  ------         --------------   -----  
 0   Date           145460 non-null  object 
 1   Location       145460 non-null  object 
 2   MinTemp        143975 non-null  float64
 3   MaxTemp        144199 non-null  float64
 4   Rainfall       142199 non-null  float64
 5   Evaporation    82670 non-null   float64
 6   Sunshine       75625 non-null   float64
 7   WindGustDir    135134 non-null  object 
 8   WindGustSpeed  135197 non-null  float64
 9   WindDir9am     134894 non-null  object 
 10  WindDir3pm     141232 non-null  object 
 11  WindSpeed9am   143693 non-null  float64
 12  WindSpeed3pm   142398 non-null  float64
 13  Humidity9am    142806 non-null  float64
 14  Humidity3pm    140953 non-null  float64
 15  Pressure9am    130395 non-null  float64
 16  Pressure3pm    130432 non-null  float64
 17  Cloud9am       89572 non-null   float64
 18  Cloud3pm       86102 non-null   float64
 19  Temp9am        143693 non-null  float64
 20  Temp3pm        141851 non-null  float64
 21  RainToday      142199 non-null  object 
 22  RainTomorrow   142193 non-null  object 
dtypes: float64(16), object(7)
memory usage: 25.5+ MB
raw_df.dropna(subset=['RainTomorrow'], inplace=True)
raw_df.head(2)
Date Location MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustDir WindGustSpeed WindDir9am WindDir3pm WindSpeed9am WindSpeed3pm Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm RainToday RainTomorrow
0 2008-12-01 Albury 13.4 22.9 0.6 NaN NaN W 44.0 W WNW 20.0 24.0 71.0 22.0 1007.7 1007.1 8.0 NaN 16.9 21.8 No No
1 2008-12-02 Albury 7.4 25.1 0.0 NaN NaN WNW 44.0 NNW WSW 4.0 22.0 44.0 25.0 1010.6 1007.8 NaN NaN 17.2 24.3 No No
raw_df.shape # shape has become 142193
(142193, 23)

Training Validation and Test Sets

plt.title("no.of Rows per Year")
sns.countplot(x=pd.to_datetime(raw_df.Date).dt.year);
year = pd.to_datetime(raw_df.Date).dt.year

train_df = raw_df[year<2015]
val_df = raw_df[year==2015]
test_df = raw_df[year>2015]

print(train_df.shape, val_df.shape, test_df.shape)
(98988, 23) (17231, 23) (25974, 23)

Input and Target Columns

input_cols = list(train_df.columns)[1:-1]
target_cols = 'RainTomorrow'
target_cols
'RainTomorrow'
input_cols
['Location',
 'MinTemp',
 'MaxTemp',
 'Rainfall',
 'Evaporation',
 'Sunshine',
 'WindGustDir',
 'WindGustSpeed',
 'WindDir9am',
 'WindDir3pm',
 'WindSpeed9am',
 'WindSpeed3pm',
 'Humidity9am',
 'Humidity3pm',
 'Pressure9am',
 'Pressure3pm',
 'Cloud9am',
 'Cloud3pm',
 'Temp9am',
 'Temp3pm',
 'RainToday']
train_inputs = train_df[input_cols].copy()
train_targets = train_df[target_cols].copy()

val_inputs = val_df[input_cols].copy()
val_targets = val_df[target_cols].copy()

test_inputs = test_df[input_cols].copy()
test_targets = test_df[target_cols].copy()
numeric_cols = train_inputs.select_dtypes(include=np.number).columns.tolist()
categorical_cols = train_inputs.select_dtypes('object').columns.tolist()
print(numeric_cols)
['MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine', 'WindGustSpeed', 'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm', 'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am', 'Temp3pm']
print(categorical_cols)
['Location', 'WindGustDir', 'WindDir9am', 'WindDir3pm', 'RainToday']

Imputing Missing Numeric Values

train_inputs[numeric_cols].isna().sum().sort_values(ascending=False)
Sunshine         40696
Evaporation      37110
Cloud3pm         36766
Cloud9am         35764
Pressure9am       9345
Pressure3pm       9309
WindGustSpeed     6902
Humidity9am       1265
Humidity3pm       1186
WindSpeed3pm      1140
WindSpeed9am      1133
Rainfall          1000
Temp9am            783
Temp3pm            663
MinTemp            434
MaxTemp            198
dtype: int64
from sklearn.impute import SimpleImputer
imputer = SimpleImputer(strategy = 'mean').fit(raw_df[numeric_cols]) # imputer will figureout the avg for each of cols
train_inputs[numeric_cols] = imputer.transform(train_inputs[numeric_cols]) # fill empty data
val_inputs[numeric_cols] = imputer.transform(val_inputs[numeric_cols])
test_inputs[numeric_cols] = imputer.transform(test_inputs[numeric_cols])
train_inputs[numeric_cols].isna().sum()
MinTemp          0
MaxTemp          0
Rainfall         0
Evaporation      0
Sunshine         0
WindGustSpeed    0
WindSpeed9am     0
WindSpeed3pm     0
Humidity9am      0
Humidity3pm      0
Pressure9am      0
Pressure3pm      0
Cloud9am         0
Cloud3pm         0
Temp9am          0
Temp3pm          0
dtype: int64

Scaling Numeric Features

from sklearn.preprocessing import MinMaxScaler
val_inputs.describe().loc[['min', 'max']]
MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustSpeed WindSpeed9am WindSpeed3pm Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm
min -8.2 -3.2 0.0 0.0 0.0 7.0 0.0 0.0 4.0 0.0 988.1 982.2 0.0 0.0 -6.2 -4.0
max 31.9 45.4 247.2 70.4 14.5 135.0 87.0 74.0 100.0 100.0 1039.3 1037.3 8.0 8.0 37.5 42.8
scaler = MinMaxScaler().fit(raw_df[numeric_cols])
train_inputs[numeric_cols] = scaler.transform(train_inputs[numeric_cols])
val_inputs[numeric_cols] = scaler.transform(val_inputs[numeric_cols])
test_inputs[numeric_cols] = scaler.transform(test_inputs[numeric_cols])
val_inputs.describe().loc[['min', 'max']]
MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustSpeed WindSpeed9am WindSpeed3pm Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm
min 0.007075 0.030246 0.000000 0.000000 0.0 0.007752 0.000000 0.000000 0.04 0.0 0.125620 0.0816 0.000000 0.000000 0.021097 0.026871
max 0.952830 0.948960 0.666307 0.485517 1.0 1.000000 0.669231 0.850575 1.00 1.0 0.971901 0.9632 0.888889 0.888889 0.943038 0.925144

Encoding Categorical Data

from sklearn.preprocessing import OneHotEncoder
train_df[categorical_cols].fillna('Unkown')
val_df[categorical_cols].fillna('Unkown')
test_df[categorical_cols].fillna('Unknown')
Location WindGustDir WindDir9am WindDir3pm RainToday
2498 Albury ENE Unknown ESE No
2499 Albury SSE SSE SE No
2500 Albury ENE ESE ENE Yes
2501 Albury SSE SE SSE Yes
2502 Albury ENE SE SSE Yes
... ... ... ... ... ...
145454 Uluru E ESE E No
145455 Uluru E SE ENE No
145456 Uluru NNW SE N No
145457 Uluru N SE WNW No
145458 Uluru SE SSE N No

25974 rows × 5 columns

encoder = OneHotEncoder(sparse=False, handle_unknown='ignore').fit(raw_df[categorical_cols])
encoded_cols = list(encoder.get_feature_names(categorical_cols))  
train_inputs[encoded_cols] = encoder.transform(train_inputs[categorical_cols])
val_inputs[encoded_cols] = encoder.transform(val_inputs[categorical_cols])
test_inputs[encoded_cols] = encoder.transform(test_inputs[categorical_cols])
print(encoded_cols)
['Location_Adelaide', 'Location_Albany', 'Location_Albury', 'Location_AliceSprings', 'Location_BadgerysCreek', 'Location_Ballarat', 'Location_Bendigo', 'Location_Brisbane', 'Location_Cairns', 'Location_Canberra', 'Location_Cobar', 'Location_CoffsHarbour', 'Location_Dartmoor', 'Location_Darwin', 'Location_GoldCoast', 'Location_Hobart', 'Location_Katherine', 'Location_Launceston', 'Location_Melbourne', 'Location_MelbourneAirport', 'Location_Mildura', 'Location_Moree', 'Location_MountGambier', 'Location_MountGinini', 'Location_Newcastle', 'Location_Nhil', 'Location_NorahHead', 'Location_NorfolkIsland', 'Location_Nuriootpa', 'Location_PearceRAAF', 'Location_Penrith', 'Location_Perth', 'Location_PerthAirport', 'Location_Portland', 'Location_Richmond', 'Location_Sale', 'Location_SalmonGums', 'Location_Sydney', 'Location_SydneyAirport', 'Location_Townsville', 'Location_Tuggeranong', 'Location_Uluru', 'Location_WaggaWagga', 'Location_Walpole', 'Location_Watsonia', 'Location_Williamtown', 'Location_Witchcliffe', 'Location_Wollongong', 'Location_Woomera', 'WindGustDir_E', 'WindGustDir_ENE', 'WindGustDir_ESE', 'WindGustDir_N', 'WindGustDir_NE', 'WindGustDir_NNE', 'WindGustDir_NNW', 'WindGustDir_NW', 'WindGustDir_S', 'WindGustDir_SE', 'WindGustDir_SSE', 'WindGustDir_SSW', 'WindGustDir_SW', 'WindGustDir_W', 'WindGustDir_WNW', 'WindGustDir_WSW', 'WindGustDir_nan', 'WindDir9am_E', 'WindDir9am_ENE', 'WindDir9am_ESE', 'WindDir9am_N', 'WindDir9am_NE', 'WindDir9am_NNE', 'WindDir9am_NNW', 'WindDir9am_NW', 'WindDir9am_S', 'WindDir9am_SE', 'WindDir9am_SSE', 'WindDir9am_SSW', 'WindDir9am_SW', 'WindDir9am_W', 'WindDir9am_WNW', 'WindDir9am_WSW', 'WindDir9am_nan', 'WindDir3pm_E', 'WindDir3pm_ENE', 'WindDir3pm_ESE', 'WindDir3pm_N', 'WindDir3pm_NE', 'WindDir3pm_NNE', 'WindDir3pm_NNW', 'WindDir3pm_NW', 'WindDir3pm_S', 'WindDir3pm_SE', 'WindDir3pm_SSE', 'WindDir3pm_SSW', 'WindDir3pm_SW', 'WindDir3pm_W', 'WindDir3pm_WNW', 'WindDir3pm_WSW', 'WindDir3pm_nan', 'RainToday_No', 'RainToday_Yes', 'RainToday_nan']
train_inputs.head(10)
Location MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustDir WindGustSpeed WindDir9am WindDir3pm WindSpeed9am WindSpeed3pm Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm RainToday Location_Adelaide Location_Albany Location_Albury Location_AliceSprings Location_BadgerysCreek Location_Ballarat Location_Bendigo Location_Brisbane Location_Cairns Location_Canberra Location_Cobar Location_CoffsHarbour Location_Dartmoor Location_Darwin Location_GoldCoast Location_Hobart Location_Katherine Location_Launceston Location_Melbourne Location_MelbourneAirport Location_Mildura Location_Moree Location_MountGambier Location_MountGinini Location_Newcastle Location_Nhil Location_NorahHead Location_NorfolkIsland Location_Nuriootpa Location_PearceRAAF Location_Penrith Location_Perth Location_PerthAirport Location_Portland Location_Richmond Location_Sale Location_SalmonGums Location_Sydney Location_SydneyAirport Location_Townsville Location_Tuggeranong Location_Uluru Location_WaggaWagga Location_Walpole Location_Watsonia Location_Williamtown Location_Witchcliffe Location_Wollongong Location_Woomera WindGustDir_E WindGustDir_ENE WindGustDir_ESE WindGustDir_N WindGustDir_NE WindGustDir_NNE WindGustDir_NNW WindGustDir_NW WindGustDir_S WindGustDir_SE WindGustDir_SSE WindGustDir_SSW WindGustDir_SW WindGustDir_W WindGustDir_WNW WindGustDir_WSW WindGustDir_nan WindDir9am_E WindDir9am_ENE WindDir9am_ESE WindDir9am_N WindDir9am_NE WindDir9am_NNE WindDir9am_NNW WindDir9am_NW WindDir9am_S WindDir9am_SE WindDir9am_SSE WindDir9am_SSW WindDir9am_SW WindDir9am_W WindDir9am_WNW WindDir9am_WSW WindDir9am_nan WindDir3pm_E WindDir3pm_ENE WindDir3pm_ESE WindDir3pm_N WindDir3pm_NE WindDir3pm_NNE WindDir3pm_NNW WindDir3pm_NW WindDir3pm_S WindDir3pm_SE WindDir3pm_SSE WindDir3pm_SSW WindDir3pm_SW WindDir3pm_W WindDir3pm_WNW WindDir3pm_WSW WindDir3pm_nan RainToday_No RainToday_Yes RainToday_nan
0 Albury 0.516509 0.523629 0.001617 0.037723 0.525852 W 0.294574 W WNW 0.153846 0.275862 0.71 0.22 0.449587 0.4800 0.888889 0.500352 0.508439 0.522073 No 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0
1 Albury 0.375000 0.565217 0.000000 0.037723 0.525852 WNW 0.294574 NNW WSW 0.030769 0.252874 0.44 0.25 0.497521 0.4912 0.493021 0.500352 0.514768 0.570058 No 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0
2 Albury 0.504717 0.576560 0.000000 0.037723 0.525852 WSW 0.310078 W WSW 0.146154 0.298851 0.38 0.30 0.447934 0.5056 0.493021 0.222222 0.594937 0.548944 No 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0
3 Albury 0.417453 0.620038 0.000000 0.037723 0.525852 NE 0.139535 SE E 0.084615 0.103448 0.45 0.16 0.613223 0.5712 0.493021 0.500352 0.533755 0.612284 No 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
4 Albury 0.613208 0.701323 0.002695 0.037723 0.525852 W 0.271318 ENE NW 0.053846 0.229885 0.82 0.33 0.500826 0.4624 0.777778 0.888889 0.527426 0.673704 No 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
5 Albury 0.544811 0.652174 0.000539 0.037723 0.525852 WNW 0.387597 W W 0.146154 0.275862 0.55 0.23 0.474380 0.4528 0.493021 0.500352 0.586498 0.658349 No 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0
6 Albury 0.537736 0.563327 0.000000 0.037723 0.525852 W 0.341085 SW W 0.153846 0.275862 0.49 0.19 0.480992 0.4976 0.111111 0.500352 0.533755 0.575816 No 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0
7 Albury 0.382075 0.595463 0.000000 0.037723 0.525852 W 0.224806 SSE W 0.046154 0.195402 0.48 0.19 0.543802 0.5280 0.493021 0.500352 0.495781 0.593090 No 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0
8 Albury 0.429245 0.693762 0.000000 0.037723 0.525852 NNW 0.573643 SE NW 0.053846 0.321839 0.42 0.09 0.469421 0.4240 0.493021 0.500352 0.537975 0.683301 No 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
9 Albury 0.509434 0.659735 0.003774 0.037723 0.525852 W 0.170543 S SSE 0.115385 0.126437 0.58 0.27 0.438017 0.4576 0.493021 0.500352 0.575949 0.644914 Yes 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
X_train = train_inputs[numeric_cols + encoded_cols]
X_val = val_inputs[numeric_cols + encoded_cols]
X_test =  test_inputs[numeric_cols + encoded_cols]
X_test.head(10)
MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustSpeed WindSpeed9am WindSpeed3pm Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm Location_Adelaide Location_Albany Location_Albury Location_AliceSprings Location_BadgerysCreek Location_Ballarat Location_Bendigo Location_Brisbane Location_Cairns Location_Canberra Location_Cobar Location_CoffsHarbour Location_Dartmoor Location_Darwin Location_GoldCoast Location_Hobart Location_Katherine Location_Launceston Location_Melbourne Location_MelbourneAirport Location_Mildura Location_Moree Location_MountGambier Location_MountGinini Location_Newcastle Location_Nhil Location_NorahHead Location_NorfolkIsland Location_Nuriootpa Location_PearceRAAF Location_Penrith Location_Perth Location_PerthAirport Location_Portland Location_Richmond Location_Sale Location_SalmonGums Location_Sydney Location_SydneyAirport Location_Townsville Location_Tuggeranong Location_Uluru Location_WaggaWagga Location_Walpole Location_Watsonia Location_Williamtown Location_Witchcliffe Location_Wollongong Location_Woomera WindGustDir_E WindGustDir_ENE WindGustDir_ESE WindGustDir_N WindGustDir_NE WindGustDir_NNE WindGustDir_NNW WindGustDir_NW WindGustDir_S WindGustDir_SE WindGustDir_SSE WindGustDir_SSW WindGustDir_SW WindGustDir_W WindGustDir_WNW WindGustDir_WSW WindGustDir_nan WindDir9am_E WindDir9am_ENE WindDir9am_ESE WindDir9am_N WindDir9am_NE WindDir9am_NNE WindDir9am_NNW WindDir9am_NW WindDir9am_S WindDir9am_SE WindDir9am_SSE WindDir9am_SSW WindDir9am_SW WindDir9am_W WindDir9am_WNW WindDir9am_WSW WindDir9am_nan WindDir3pm_E WindDir3pm_ENE WindDir3pm_ESE WindDir3pm_N WindDir3pm_NE WindDir3pm_NNE WindDir3pm_NNW WindDir3pm_NW WindDir3pm_S WindDir3pm_SE WindDir3pm_SSE WindDir3pm_SSW WindDir3pm_SW WindDir3pm_W WindDir3pm_WNW WindDir3pm_WSW WindDir3pm_nan RainToday_No RainToday_Yes RainToday_nan
2498 0.681604 0.801512 0.000000 0.037723 0.525852 0.372093 0.000000 0.080460 0.46 0.17 0.543802 0.5136 0.777778 0.333333 0.702532 0.808061 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
2499 0.693396 0.725898 0.001078 0.037723 0.525852 0.341085 0.069231 0.195402 0.54 0.30 0.505785 0.5008 0.888889 0.888889 0.675105 0.712092 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
2500 0.634434 0.527410 0.005930 0.037723 0.525852 0.325581 0.084615 0.448276 0.62 0.67 0.553719 0.6032 0.888889 0.888889 0.611814 0.477927 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
2501 0.608491 0.538752 0.042049 0.037723 0.525852 0.255814 0.069231 0.195402 0.74 0.65 0.618182 0.6304 0.888889 0.888889 0.556962 0.518234 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
2502 0.566038 0.523629 0.018329 0.037723 0.525852 0.193798 0.046154 0.103448 0.92 0.63 0.591736 0.5888 0.888889 0.888889 0.514768 0.529750 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
2503 0.601415 0.621928 0.000539 0.037723 0.525852 0.255814 0.069231 0.126437 0.76 0.52 0.563636 0.5680 0.888889 0.888889 0.580169 0.596929 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
2504 0.587264 0.620038 0.000000 0.037723 0.525852 0.224806 0.153846 0.229885 0.46 0.31 0.609917 0.6176 0.493021 0.222222 0.592827 0.614203 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
2505 0.537736 0.689981 0.000000 0.037723 0.525852 0.139535 0.084615 0.068966 0.63 0.24 0.646281 0.6416 0.493021 0.888889 0.561181 0.654511 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
2506 0.594340 0.752363 0.000000 0.037723 0.525852 0.170543 0.084615 0.103448 0.52 0.24 0.629752 0.6144 0.493021 0.333333 0.662447 0.738964 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
2507 0.620283 0.790170 0.000000 0.037723 0.525852 0.271318 0.069231 0.195402 0.54 0.17 0.596694 0.5680 0.493021 0.500352 0.704641 0.798464 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0

Training and Visualizing Decision Trees

Training

from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier(random_state=42) # random state is provided to get same value each time
%%time
model.fit(X_train, train_targets)
CPU times: user 1.9 s, sys: 9.74 ms, total: 1.91 s
Wall time: 1.92 s
DecisionTreeClassifier(random_state=42)

Evaluation

from sklearn.metrics import accuracy_score, confusion_matrix
train_preds = model.predict(X_train)
train_preds
array(['No', 'No', 'No', ..., 'No', 'No', 'No'], dtype=object)
pd.value_counts(train_preds)
No     76707
Yes    22281
dtype: int64

Decision tree also returns probabilities of each prediction

train_probs = model.predict_proba(X_train)
train_probs
array([[1., 0.],
       [1., 0.],
       [1., 0.],
       ...,
       [1., 0.],
       [1., 0.],
       [1., 0.]])
train_targets
0         No
1         No
2         No
3         No
4         No
          ..
144548    No
144549    No
144550    No
144551    No
144552    No
Name: RainTomorrow, Length: 98988, dtype: object
accuracy_score(train_preds, train_targets)
0.9999797955307714
model.score(X_val, val_targets) # direct prediction on val inputs and compare accuracy

#only ~79%
0.7921188555510418
val_targets.value_counts() / len(val_targets)
No     0.788289
Yes    0.211711
Name: RainTomorrow, dtype: float64

It appears that the model has learned the training examples perfect, and doesn't generalize well to previously unseen examples. This phenomenon is called "overfitting", and reducing overfitting is one of the most important parts of any machine learning project.

Visualizing Tree

from sklearn.tree import plot_tree, export_text
plt.figure(figsize=(80, 40))
plot_tree(model, feature_names=X_train.columns, max_depth=2, filled=True)
[Text(2232.0, 1902.6000000000001, 'Humidity3pm <= 0.715\ngini = 0.349\nsamples = 98988\nvalue = [76705, 22283]'),
 Text(1116.0, 1359.0, 'Rainfall <= 0.004\ngini = 0.248\nsamples = 82418\nvalue = [70439, 11979]'),
 Text(558.0, 815.4000000000001, 'Sunshine <= 0.525\ngini = 0.198\nsamples = 69252\nvalue = [61538, 7714]'),
 Text(279.0, 271.79999999999995, '\n  (...)  \n'),
 Text(837.0, 271.79999999999995, '\n  (...)  \n'),
 Text(1674.0, 815.4000000000001, 'Humidity3pm <= 0.512\ngini = 0.438\nsamples = 13166\nvalue = [8901, 4265]'),
 Text(1395.0, 271.79999999999995, '\n  (...)  \n'),
 Text(1953.0, 271.79999999999995, '\n  (...)  \n'),
 Text(3348.0, 1359.0, 'Humidity3pm <= 0.825\ngini = 0.47\nsamples = 16570\nvalue = [6266, 10304]'),
 Text(2790.0, 815.4000000000001, 'WindGustSpeed <= 0.279\ngini = 0.499\nsamples = 9136\nvalue = [4804, 4332]'),
 Text(2511.0, 271.79999999999995, '\n  (...)  \n'),
 Text(3069.0, 271.79999999999995, '\n  (...)  \n'),
 Text(3906.0, 815.4000000000001, 'Rainfall <= 0.01\ngini = 0.316\nsamples = 7434\nvalue = [1462, 5972]'),
 Text(3627.0, 271.79999999999995, '\n  (...)  \n'),
 Text(4185.0, 271.79999999999995, '\n  (...)  \n')]

How a Decision Tree is Created

Note the gini value in each box. This is the loss function used by the decision tree to decide which column should be used for splitting the data, and at what point the column should be split. A lower Gini index indicates a better split. A perfect split (only one class on each side) has a Gini index of 0.

For a mathematical discussion of the Gini Index, watch this video: It has the following formula:

dt2

Conceptually speaking, while training the models evaluates all possible splits across all possible columns and picks the best one. Then, it recursively performs an optimal split for the two portions. In practice, however, it's very inefficient to check all possible splits, so the model uses a heuristic (predefined strategy) combined with some randomization.

Let's check the depth of the tree that was created.

model.tree_.max_depth
48
tree_text = export_text(model, max_depth=10, feature_names=list(X_train.columns))
print(tree_text[:5000])
|--- Humidity3pm <= 0.72
|   |--- Rainfall <= 0.00
|   |   |--- Sunshine <= 0.52
|   |   |   |--- Pressure3pm <= 0.58
|   |   |   |   |--- WindGustSpeed <= 0.36
|   |   |   |   |   |--- Humidity3pm <= 0.28
|   |   |   |   |   |   |--- WindDir9am_NE <= 0.50
|   |   |   |   |   |   |   |--- Location_Watsonia <= 0.50
|   |   |   |   |   |   |   |   |--- Cloud9am <= 0.83
|   |   |   |   |   |   |   |   |   |--- WindSpeed3pm <= 0.07
|   |   |   |   |   |   |   |   |   |   |--- Pressure3pm <= 0.46
|   |   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |   |--- Pressure3pm >  0.46
|   |   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |   |   |--- WindSpeed3pm >  0.07
|   |   |   |   |   |   |   |   |   |   |--- MinTemp <= 0.32
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |   |   |   |   |   |--- MinTemp >  0.32
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |   |   |   |   |--- Cloud9am >  0.83
|   |   |   |   |   |   |   |   |   |--- Cloud3pm <= 0.42
|   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |--- Cloud3pm >  0.42
|   |   |   |   |   |   |   |   |   |   |--- Rainfall <= 0.00
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |   |   |   |   |   |--- Rainfall >  0.00
|   |   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |--- Location_Watsonia >  0.50
|   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |--- WindDir9am_NE >  0.50
|   |   |   |   |   |   |   |--- WindGustSpeed <= 0.25
|   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |--- WindGustSpeed >  0.25
|   |   |   |   |   |   |   |   |--- Pressure9am <= 0.54
|   |   |   |   |   |   |   |   |   |--- Evaporation <= 0.09
|   |   |   |   |   |   |   |   |   |   |--- Location_AliceSprings <= 0.50
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |   |   |   |   |   |--- Location_AliceSprings >  0.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |--- Evaporation >  0.09
|   |   |   |   |   |   |   |   |   |   |--- WindGustDir_ENE <= 0.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |   |--- WindGustDir_ENE >  0.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |   |--- Pressure9am >  0.54
|   |   |   |   |   |   |   |   |   |--- Humidity3pm <= 0.20
|   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |--- Humidity3pm >  0.20
|   |   |   |   |   |   |   |   |   |   |--- Evaporation <= 0.02
|   |   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |   |--- Evaporation >  0.02
|   |   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |--- Humidity3pm >  0.28
|   |   |   |   |   |   |--- Sunshine <= 0.05
|   |   |   |   |   |   |   |--- WindGustSpeed <= 0.25
|   |   |   |   |   |   |   |   |--- Evaporation <= 0.01
|   |   |   |   |   |   |   |   |   |--- WindGustSpeed <= 0.23
|   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |--- WindGustSpeed >  0.23
|   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |   |--- Evaporation >  0.01
|   |   |   |   |   |   |   |   |   |--- Evaporation <= 0.07
|   |   |   |   |   |   |   |   |   |   |--- Temp3pm <= 0.34
|   |   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |   |--- Temp3pm >  0.34
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 11
|   |   |   |   |   |   |   |   |   |--- Evaporation >  0.07
|   |   |   |   |   |   |   |   |   |   |--- WindSpeed9am <= 0.12
|   |   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |   |--- WindSpeed9am >  0.12
|   |   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |--- WindGustSpeed >  0.25
|   |   |   |   |   |   |   |   |--- Pressure9am <= 0.56
|   |   |   |   |   |   |   |   |   |--- MinTemp <= 0.40
|   |   |   |   |   |   |   |   |   |   |--- WindDir9am_WNW <= 0.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |   |   |--- WindDir9am_WNW >  0.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |   |   |--- MinTemp >  0.40
|   |   |   |   |   |   |   |   |   |   |--- Humidity3pm <= 0.66
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |   |   |   |   |   |   |--- Humidity3pm >  0.66
|   |   |   |   |   |   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |   |   |   |--- Pressure9am >  0.56
|   |   |   |   |   

Feature Importance

X_train.columns
Index(['MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
       'WindGustSpeed', 'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am',
       'Humidity3pm',
       ...
       'WindDir3pm_SSE', 'WindDir3pm_SSW', 'WindDir3pm_SW', 'WindDir3pm_W',
       'WindDir3pm_WNW', 'WindDir3pm_WSW', 'WindDir3pm_nan', 'RainToday_No',
       'RainToday_Yes', 'RainToday_nan'],
      dtype='object', length=119)
model.feature_importances_
array([3.48942086e-02, 3.23605486e-02, 5.91385668e-02, 2.49619907e-02,
       4.94652143e-02, 5.63334673e-02, 2.80205998e-02, 2.98128801e-02,
       4.02182908e-02, 2.61441297e-01, 3.44145027e-02, 6.20573699e-02,
       1.36406176e-02, 1.69229866e-02, 3.50001550e-02, 3.04064076e-02,
       2.24086587e-03, 2.08018104e-03, 1.27475954e-03, 7.26936324e-04,
       1.39779517e-03, 1.15264873e-03, 6.92808159e-04, 1.80675598e-03,
       1.08370901e-03, 1.19773895e-03, 8.87119103e-04, 2.15764220e-03,
       1.67094731e-03, 7.98919493e-05, 1.10558668e-03, 1.42008656e-03,
       4.10087635e-04, 1.09028115e-03, 1.44164766e-03, 9.08284767e-04,
       1.05770304e-03, 6.18133455e-04, 1.80387272e-03, 2.10403527e-03,
       2.74413333e-04, 7.31599405e-04, 1.35408990e-03, 1.54759332e-03,
       1.30917564e-03, 1.07134670e-03, 8.36408023e-04, 1.62662229e-03,
       1.00326116e-03, 2.16053455e-03, 8.46802258e-04, 1.88919081e-03,
       9.29325203e-04, 1.29545157e-03, 1.27604831e-03, 5.12736888e-04,
       1.38458902e-03, 3.97103931e-04, 1.03734689e-03, 1.44437047e-03,
       1.75870184e-03, 1.42487857e-03, 2.78109569e-03, 2.00782698e-03,
       2.80617652e-04, 1.61509734e-03, 1.64361598e-03, 2.36124112e-03,
       3.05457932e-03, 2.33239534e-03, 2.78643875e-03, 2.16695261e-03,
       3.41491352e-03, 2.30573542e-03, 2.28270604e-03, 2.34408118e-03,
       2.26557332e-03, 2.54592702e-03, 2.75264499e-03, 2.83905192e-03,
       2.49480561e-03, 1.54840338e-03, 2.50305095e-03, 2.53945388e-03,
       2.28130055e-03, 3.80572180e-03, 2.58535069e-03, 3.10172224e-03,
       2.54236791e-03, 2.50297796e-03, 2.06400988e-03, 2.52931192e-03,
       2.07840517e-03, 1.77387278e-03, 1.78920555e-03, 2.77709687e-03,
       2.42564566e-03, 2.26471887e-03, 1.73346117e-03, 2.23926957e-03,
       2.47865244e-03, 2.31917387e-03, 3.21211861e-03, 2.92382975e-03,
       2.24399274e-03, 3.68774754e-03, 3.87595982e-03, 3.20326068e-03,
       2.53323550e-03, 2.40444844e-03, 2.26790411e-03, 2.19744009e-03,
       2.28064147e-03, 2.88545323e-03, 2.05278867e-03, 1.12604304e-03,
       2.86325849e-04, 1.32322128e-03, 1.72690480e-03])
importance_df = pd.DataFrame({
    'feature': X_train.columns,
    'importance': model.feature_importances_
}).sort_values('importance', ascending=False)
importance_df.head(10)
feature importance
9 Humidity3pm 0.261441
11 Pressure3pm 0.062057
2 Rainfall 0.059139
5 WindGustSpeed 0.056333
4 Sunshine 0.049465
8 Humidity9am 0.040218
14 Temp9am 0.035000
0 MinTemp 0.034894
10 Pressure9am 0.034415
1 MaxTemp 0.032361
plt.title('Feature Importance')
sns.barplot(data=importance_df.head(10), x='importance', y='feature');

Hyperparameter Tuning and Overfitting

?DecisionTreeClassifier

As we saw in the previous section, our decision tree classifier memorized all training examples, leading to a 100% training accuracy, while the validation accuracy was only marginally better than a dumb baseline model. This phenomenon is called overfitting, and in this section, we'll look at some strategies for reducing overfitting. The process of reducing overfitting is known as regularlization.

The DecisionTreeClassifier accepts several arguments, some of which can be modified to reduce overfitting.

These arguments are called hyperparameters because they must be configured manually (as opposed to the parameters within the model which are learned from the data. We'll explore a couple of hyperparameters:

  • max_depth
  • max_leaf_nodes

max_depth

By reducing the maximum depth of the decision tree, we can prevent the tree from memorizing all training examples, which may lead to better generalization

model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X_train, train_targets)
DecisionTreeClassifier(max_depth=3, random_state=42)
model.score(X_train, train_targets)
0.8291308037337859
model.score(X_val, val_targets)
0.8334397307178921
model.classes_
array(['No', 'Yes'], dtype=object)

Great, while the training accuracy of the model has gone down, the validation accuracy of the model has increased significantly.

plt.figure(figsize=(80, 40))
plot_tree(model, feature_names=X_train.columns, filled=True, rounded=True, class_names=model.classes_)
[Text(2232.0, 1902.6000000000001, 'Humidity3pm <= 0.715\ngini = 0.349\nsamples = 98988\nvalue = [76705, 22283]\nclass = No'),
 Text(1116.0, 1359.0, 'Rainfall <= 0.004\ngini = 0.248\nsamples = 82418\nvalue = [70439, 11979]\nclass = No'),
 Text(558.0, 815.4000000000001, 'Sunshine <= 0.525\ngini = 0.198\nsamples = 69252\nvalue = [61538, 7714]\nclass = No'),
 Text(279.0, 271.79999999999995, 'gini = 0.363\nsamples = 12620\nvalue = [9618, 3002]\nclass = No'),
 Text(837.0, 271.79999999999995, 'gini = 0.153\nsamples = 56632\nvalue = [51920, 4712]\nclass = No'),
 Text(1674.0, 815.4000000000001, 'Humidity3pm <= 0.512\ngini = 0.438\nsamples = 13166\nvalue = [8901, 4265]\nclass = No'),
 Text(1395.0, 271.79999999999995, 'gini = 0.293\nsamples = 4299\nvalue = [3531, 768]\nclass = No'),
 Text(1953.0, 271.79999999999995, 'gini = 0.478\nsamples = 8867\nvalue = [5370, 3497]\nclass = No'),
 Text(3348.0, 1359.0, 'Humidity3pm <= 0.825\ngini = 0.47\nsamples = 16570\nvalue = [6266, 10304]\nclass = Yes'),
 Text(2790.0, 815.4000000000001, 'WindGustSpeed <= 0.279\ngini = 0.499\nsamples = 9136\nvalue = [4804, 4332]\nclass = No'),
 Text(2511.0, 271.79999999999995, 'gini = 0.472\nsamples = 5583\nvalue = [3457, 2126]\nclass = No'),
 Text(3069.0, 271.79999999999995, 'gini = 0.471\nsamples = 3553\nvalue = [1347, 2206]\nclass = Yes'),
 Text(3906.0, 815.4000000000001, 'Rainfall <= 0.01\ngini = 0.316\nsamples = 7434\nvalue = [1462, 5972]\nclass = Yes'),
 Text(3627.0, 271.79999999999995, 'gini = 0.391\nsamples = 4360\nvalue = [1161, 3199]\nclass = Yes'),
 Text(4185.0, 271.79999999999995, 'gini = 0.177\nsamples = 3074\nvalue = [301, 2773]\nclass = Yes')]
print(export_text(model, feature_names=list(X_train.columns)))
|--- Humidity3pm <= 0.72
|   |--- Rainfall <= 0.00
|   |   |--- Sunshine <= 0.52
|   |   |   |--- class: No
|   |   |--- Sunshine >  0.52
|   |   |   |--- class: No
|   |--- Rainfall >  0.00
|   |   |--- Humidity3pm <= 0.51
|   |   |   |--- class: No
|   |   |--- Humidity3pm >  0.51
|   |   |   |--- class: No
|--- Humidity3pm >  0.72
|   |--- Humidity3pm <= 0.82
|   |   |--- WindGustSpeed <= 0.28
|   |   |   |--- class: No
|   |   |--- WindGustSpeed >  0.28
|   |   |   |--- class: Yes
|   |--- Humidity3pm >  0.82
|   |   |--- Rainfall <= 0.01
|   |   |   |--- class: Yes
|   |   |--- Rainfall >  0.01
|   |   |   |--- class: Yes

def max_depth_error(md):
    model = DecisionTreeClassifier(max_depth=md, random_state=42)
    model.fit(X_train, train_targets)
    train_error = 1 - model.score(X_train, train_targets)
    val_error = 1 - model.score(X_val, val_targets)
    return {'Max Depth': md, 'Training Error': train_error, 'Validation Error': val_error}
%%time
errors_df = pd.DataFrame([max_depth_error(md) for md in range(1, 21)])
CPU times: user 30.3 s, sys: 203 ms, total: 30.5 s
Wall time: 30.5 s
errors_df
Max Depth Training Error Validation Error
0 1 0.184315 0.177935
1 2 0.179547 0.172712
2 3 0.170869 0.166560
3 4 0.165707 0.164355
4 5 0.160676 0.159074
5 6 0.156271 0.157275
6 7 0.153312 0.154605
7 8 0.147806 0.158029
8 9 0.140906 0.156578
9 10 0.132945 0.157333
10 11 0.123227 0.159248
11 12 0.113489 0.160815
12 13 0.101750 0.163833
13 14 0.089981 0.167373
14 15 0.078999 0.171261
15 16 0.068180 0.174279
16 17 0.058138 0.176890
17 18 0.048733 0.181243
18 19 0.040025 0.187569
19 20 0.032539 0.190297
plt.figure()
plt.plot(errors_df['Max Depth'], errors_df['Training Error'])
plt.plot(errors_df['Max Depth'], errors_df['Validation Error'])
plt.title("Training vs Validation Error")
plt.xticks(range(0,21,2))
plt.xlabel('Max. Depth')
plt.ylabel('Prediction Error ie 1-Accuracy')
plt.legend(['Training', 'Validation'])
<matplotlib.legend.Legend at 0x7f037824ffa0>

overfitting

So for us max depth of 7 results in lowest validation error

model = DecisionTreeClassifier(max_depth=7, random_state=42).fit(X_train, train_targets)
model.score(X_val, val_targets), model.score(X_train, train_targets)
(0.8453949277465034, 0.8466884874934335)

max_leaf_nodes

Another way to control the size of complexity of a decision tree is to limit the number of leaf nodes. This allows branches of the tree to have varying depths.

model = DecisionTreeClassifier(max_leaf_nodes = 128, random_state = 42)
model.fit(X_train, train_targets)
DecisionTreeClassifier(max_leaf_nodes=128, random_state=42)
model.score(X_train, train_targets)
0.8480421869317493
model.score(X_val, val_targets)
0.8442342290058615
model.tree_.max_depth
12

Notice that the model was able to achieve a greater depth of 12 for certain paths while keeping other paths shorter.

model_text = export_text(model, feature_names = list(X_train.columns))
print(model_text[:3000])
|--- Humidity3pm <= 0.72
|   |--- Rainfall <= 0.00
|   |   |--- Sunshine <= 0.52
|   |   |   |--- Pressure3pm <= 0.58
|   |   |   |   |--- WindGustSpeed <= 0.36
|   |   |   |   |   |--- Humidity3pm <= 0.28
|   |   |   |   |   |   |--- class: No
|   |   |   |   |   |--- Humidity3pm >  0.28
|   |   |   |   |   |   |--- Sunshine <= 0.05
|   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |--- Sunshine >  0.05
|   |   |   |   |   |   |   |--- Pressure3pm <= 0.43
|   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |--- Pressure3pm >  0.43
|   |   |   |   |   |   |   |   |--- Humidity3pm <= 0.57
|   |   |   |   |   |   |   |   |   |--- WindDir9am_NE <= 0.50
|   |   |   |   |   |   |   |   |   |   |--- WindDir9am_NNE <= 0.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |   |   |   |--- WindDir9am_NNE >  0.50
|   |   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |   |   |--- WindDir9am_NE >  0.50
|   |   |   |   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |   |   |--- Humidity3pm >  0.57
|   |   |   |   |   |   |   |   |   |--- MaxTemp <= 0.53
|   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |   |   |--- MaxTemp >  0.53
|   |   |   |   |   |   |   |   |   |   |--- Temp3pm <= 0.67
|   |   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |   |   |   |--- Temp3pm >  0.67
|   |   |   |   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |--- WindGustSpeed >  0.36
|   |   |   |   |   |--- Humidity3pm <= 0.45
|   |   |   |   |   |   |--- Sunshine <= 0.39
|   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |--- Sunshine >  0.39
|   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |--- Humidity3pm >  0.45
|   |   |   |   |   |   |--- Pressure3pm <= 0.49
|   |   |   |   |   |   |   |--- class: Yes
|   |   |   |   |   |   |--- Pressure3pm >  0.49
|   |   |   |   |   |   |   |--- class: Yes
|   |   |   |--- Pressure3pm >  0.58
|   |   |   |   |--- Pressure3pm <= 0.70
|   |   |   |   |   |--- Sunshine <= 0.32
|   |   |   |   |   |   |--- WindDir9am_N <= 0.50
|   |   |   |   |   |   |   |--- Humidity3pm <= 0.67
|   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |   |--- Humidity3pm >  0.67
|   |   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |--- WindDir9am_N >  0.50
|   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |--- Sunshine >  0.32
|   |   |   |   |   |   |--- WindGustSpeed <= 0.33
|   |   |   |   |   |   |   |--- class: No
|   |   |   |   |   |   |--- WindGustSpeed >  0.33
|   |   |   |   |   |   |   |--- class: No
|   |   |   |   |--- Pressure3pm >  0.70
|   |   |   |   |   |--- Location_CoffsHarbour <= 0.50
|   |   |   |   |   |   |--- class: No
|   |   |   |   |   |--- Location_CoffsHarbour >  0.50
|   |   |   |   |   |   |--- class: No
|   |   |--- Sunshine >  0.52
|   |   |