Class: Ai4r::Classifiers::ID3
- Inherits: Classifier (Object → Classifier → Ai4r::Classifiers::ID3)
- Defined in: lib/ai4r/classifiers/id3.rb
Overview
Introduction
This is an implementation of the ID3 algorithm (Quinlan). Given a set of preclassified examples, it builds a decision tree by top-down induction, choosing each split by information gain, an entropy-based measure.
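For intuition, here is the entropy measure in plain Ruby — a minimal textbook sketch, not the library's internal code:
# Shannon entropy of a list of class labels, in bits.
def entropy(labels)
  labels.group_by(&:itself).values.sum do |group|
    p = group.length.to_f / labels.length
    -p * Math.log2(p)
  end
end
entropy(%w[Y Y Y N]) # => ~0.811; 0.0 would mean a perfectly pure set
At each node, ID3 picks the attribute whose split yields the largest drop in this measure (the information gain).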
How to use it
DATA_LABELS = [ 'city', 'age_range', 'gender', 'marketing_target' ]
DATA_ITEMS = [
['New York', '<30', 'M', 'Y'],
['Chicago', '<30', 'M', 'Y'],
['Chicago', '<30', 'F', 'Y'],
['New York', '<30', 'M', 'Y'],
['New York', '<30', 'M', 'Y'],
['Chicago', '[30-50)', 'M', 'Y'],
['New York', '[30-50)', 'F', 'N'],
['Chicago', '[30-50)', 'F', 'Y'],
['New York', '[30-50)', 'F', 'N'],
['Chicago', '[50-80]', 'M', 'N'],
['New York', '[50-80]', 'F', 'N'],
['New York', '[50-80]', 'M', 'N'],
['Chicago', '[50-80]', 'M', 'N'],
['New York', '[50-80]', 'F', 'N'],
['Chicago', '>80', 'F', 'Y']
]
data_set = DataSet.new(:data_items=>DATA_ITEMS, :data_labels=>DATA_LABELS)
id3 = Ai4r::Classifiers::ID3.new.build(data_set)
id3.get_rules
# => if age_range=='<30' then marketing_target='Y'
elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'
elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'
elsif age_range=='[50-80]' then marketing_target='N'
elsif age_range=='>80' then marketing_target='Y'
else
raise 'There was not enough information during training to do '
'a proper induction for this data element'
end
id3.eval(['New York', '<30', 'M'])
# => 'Y'
A better way to load the data
In real life you will use many more training examples, with more attributes. Consider moving your data to an external CSV (comma-separated values) file.
data_file = "#{File.dirname(__FILE__)}/data_set.csv"
data_set = DataSet.new.load_csv_with_labels data_file
id3 = Ai4r::Classifiers::ID3.new.build(data_set)
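As the method name suggests, the first CSV row is expected to hold the attribute labels, followed by one data item per row. For the example above, data_set.csv could start like this (remaining rows follow the same pattern):
city,age_range,gender,marketing_target
New York,<30,M,Y
Chicago,<30,M,Y
Chicago,<30,F,Y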
A nice tip for data evaluation
id3 = Ai4r::Classifiers::ID3.new.build(data_set)
age_range = '<30'
marketing_target = nil
eval id3.get_rules
puts marketing_target
# => 'Y'
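Note that every attribute name referenced by the generated rules must exist as a local variable in the binding where Kernel#eval runs; a branch that tests city, for example, needs city defined. A hypothetical extension of the tip above:
city = 'Chicago' # needed by the '[30-50)' branches of the rules
age_range = '[30-50)'
marketing_target = nil
eval id3.get_rules
puts marketing_target
# => 'Y'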
About the project
- Author: Sergio Fierens
- License: MPL 1.1
Instance Attribute Summary
- #data_set ⇒ Object (readonly)
  Returns the value of attribute data_set.
- #majority_class ⇒ Object (readonly)
  Returns the value of attribute majority_class.
- #validation_set ⇒ Object (readonly)
  Returns the value of attribute validation_set.
Class Method Summary
Instance Method Summary
- #build(data_set, options = {}) ⇒ Object
  Create a new ID3 classifier.
- #build_node(data_examples, flag_att = [], depth = 0) ⇒ Object
- #eval(data) ⇒ Object
  You can evaluate new data, predicting its category.
- #get_rules ⇒ Object
  This method returns the generated rules as Ruby code.
- #initialize ⇒ Object constructor
- #preprocess_data(data_examples) ⇒ Object
- #prune! ⇒ Object
  Prune the decision tree using the validation set provided during build.
- #to_graphviz ⇒ Object
  Generate GraphViz DOT syntax describing the decision tree.
- #to_h ⇒ Object
  Return a nested Hash representation of the decision tree.
Methods included from Data::Parameterizable
#get_parameters, included, #set_parameters
Constructor Details
#initialize ⇒ Object
# File 'lib/ai4r/classifiers/id3.rb', line 103

def initialize
  super()
  @max_depth = nil
  @min_gain = 0
  @on_unknown = :raise
end
Instance Attribute Details
#data_set ⇒ Object (readonly)
Returns the value of attribute data_set.
# File 'lib/ai4r/classifiers/id3.rb', line 96

def data_set
  @data_set
end
#majority_class ⇒ Object (readonly)
Returns the value of attribute majority_class.
# File 'lib/ai4r/classifiers/id3.rb', line 96

def majority_class
  @majority_class
end
#validation_set ⇒ Object (readonly)
Returns the value of attribute validation_set.
# File 'lib/ai4r/classifiers/id3.rb', line 96

def validation_set
  @validation_set
end
Class Method Details
.log2(z) ⇒ Object
# File 'lib/ai4r/classifiers/id3.rb', line 285

def self.log2(z)
  return 0.0 if z.zero?

  Math.log(z) / LOG2
end
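Assuming LOG2 is the constant Math.log(2), as the division above implies, this helper is a base-2 logarithm that treats log2(0) as 0.0 — the usual convention when accumulating entropy terms:
Ai4r::Classifiers::ID3.log2(8) # => 3.0
Ai4r::Classifiers::ID3.log2(0) # => 0.0 (avoids -Infinity in entropy sums)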
.sum(values) ⇒ Object
# File 'lib/ai4r/classifiers/id3.rb', line 279

def self.sum(values)
  values.sum
end
Instance Method Details
#build(data_set, options = {}) ⇒ Object
Create a new ID3 classifier. You must provide a DataSet instance as a parameter. The last attribute of each item is considered the item class.
# File 'lib/ai4r/classifiers/id3.rb', line 116

def build(data_set, options = {})
  data_set.check_not_empty
  @data_set = data_set
  @validation_set = options[:validation_set]
  preprocess_data(@data_set.data_items)
  prune! if @validation_set
  self
end
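The constructor initializes @max_depth and @min_gain, and the class includes Data::Parameterizable (see the included methods above). Assuming those attributes are registered as parameters — an assumption worth verifying against the source — they could be tuned before building:
# Hedged sketch: assumes :max_depth and :min_gain are exposed via
# Data::Parameterizable#set_parameters.
id3 = Ai4r::Classifiers::ID3.new
id3.set_parameters(:max_depth => 3, :min_gain => 0.01)
id3.build(data_set)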
#build_node(data_examples, flag_att = [], depth = 0) ⇒ Object
# File 'lib/ai4r/classifiers/id3.rb', line 209

def build_node(data_examples, flag_att = [], depth = 0)
  return ErrorNode.new if data_examples.empty?

  domain = domain(data_examples)
  return CategoryNode.new(@data_set.category_label, domain.last[0]) if domain.last.length == 1

  if flag_att.length >= domain.length - 1
    return CategoryNode.new(@data_set.category_label, most_freq(data_examples, domain))
  end
  return CategoryNode.new(@data_set.category_label, most_freq(data_examples, domain)) if @max_depth && depth >= @max_depth

  best_index = nil
  best_entropy = nil
  best_split = nil
  best_threshold = nil
  numeric = false
  domain[0..-2].each_index do |index|
    next if flag_att.include?(index)

    if domain[index].all? { |v| v.is_a? Numeric }
      threshold, split, entropy = best_numeric_split(data_examples, index, domain)
      if best_entropy.nil? || entropy < best_entropy
        best_entropy = entropy
        best_index = index
        best_split = split
        best_threshold = threshold
        numeric = true
      end
    else
      freq_grid = freq_grid(index, data_examples, domain)
      entropy = entropy(freq_grid, data_examples.length)
      if best_entropy.nil? || entropy < best_entropy
        best_entropy = entropy
        best_index = index
        best_split = split_data_examples(data_examples, domain, index)
        numeric = false
      end
    end
  end

  gain = information_gain(data_examples, domain, best_index)
  if gain < @min_gain
    return CategoryNode.new(@data_set.category_label, most_freq(data_examples, domain))
  end
  if best_split.length == 1
    return CategoryNode.new(@data_set.category_label, most_freq(data_examples, domain))
  end

  nodes = best_split.collect do |partial_data_examples|
    build_node(partial_data_examples, numeric ? flag_att : [*flag_att, best_index], depth + 1)
  end
  majority = most_freq(data_examples, domain)
  if numeric
    EvaluationNode.new(@data_set.data_labels, best_index, best_threshold, nodes, true, majority)
  else
    EvaluationNode.new(@data_set.data_labels, best_index, domain[best_index], nodes, false, majority)
  end
end
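For intuition: the class column of the sample data above holds 8 'Y' and 7 'N' labels, so the entropy at the root is -(8/15)·log2(8/15) - (7/15)·log2(7/15) ≈ 0.997 bits, and the attribute whose split reduces it the most (age_range here) is chosen first. A quick check, independent of the internals above:
labels = DATA_ITEMS.map(&:last) # the class values, e.g. ['Y', 'Y', ..., 'Y']
counts = labels.tally           # => {"Y"=>8, "N"=>7}
root_entropy = counts.values.sum do |c|
  p = c.to_f / labels.length
  -p * Math.log2(p)
end
root_entropy # => ~0.997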
#eval(data) ⇒ Object
You can evaluate new data, predicting its category. e.g.
id3.eval(['New York', '<30', 'F']) # => 'Y'
# File 'lib/ai4r/classifiers/id3.rb', line 130

def eval(data)
  @tree&.value(data, self)
end
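Following the rules shown earlier, a couple more checks:
id3.eval(['Chicago', '[50-80]', 'M']) # => 'N' (the '[50-80]' branch)
id3.eval(['New York', '>80', 'F'])    # => 'Y' (the '>80' branch)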
#get_rules ⇒ Object
This method returns the generated rules as Ruby code, e.g.
id3.get_rules
# => if age_range=='<30' then marketing_target='Y'
elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'
elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'
elsif age_range=='[50-80]' then marketing_target='N'
elsif age_range=='>80' then marketing_target='Y'
else
raise 'There was not enough information during training to do '
'a proper induction for this data element'
end
It is a nice way to inspect induction results, and also to execute them:
age_range = '<30'
marketing_target = nil
eval id3.get_rules
puts marketing_target
# => 'Y'
# File 'lib/ai4r/classifiers/id3.rb', line 155

def get_rules
  # return "Empty ID3 tree" if !@tree
  rules = @tree.get_rules
  rules = rules.collect do |rule|
    "#{rule[0..-2].join(' and ')} then #{rule.last}"
  end
  error_msg = 'There was not enough information during training to do a proper induction for this data element'
  "if #{rules.join("\nelsif ")}\nelse raise '#{error_msg}' end"
end
#preprocess_data(data_examples) ⇒ Object
# File 'lib/ai4r/classifiers/id3.rb', line 200

def preprocess_data(data_examples)
  @majority_class = most_freq(data_examples, domain(data_examples))
  @tree = build_node(data_examples, [], 0)
end
#prune! ⇒ Object
Prune the decision tree using the validation set provided during build. Subtrees are replaced by a single leaf when this increases the classification accuracy on the validation data.
# File 'lib/ai4r/classifiers/id3.rb', line 191

def prune!
  return self unless @validation_set

  @tree = prune_node(@tree, @validation_set.data_items)
  self
end
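Since #build calls prune! automatically when a :validation_set option is given, reduced-error pruning can be enabled like this (VALIDATION_ITEMS is a hypothetical array of held-out rows in the same format as DATA_ITEMS):
# VALIDATION_ITEMS is hypothetical held-out data, same shape as DATA_ITEMS.
validation = DataSet.new(:data_items=>VALIDATION_ITEMS, :data_labels=>DATA_LABELS)
id3 = Ai4r::Classifiers::ID3.new.build(data_set, :validation_set => validation)
# Calling prune! again is harmless: it is a no-op without a validation set.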
#to_graphviz ⇒ Object
Generate GraphViz DOT syntax describing the decision tree. Nodes are labeled with attribute names or category values and edges are labeled with attribute values.
# File 'lib/ai4r/classifiers/id3.rb', line 178

def to_graphviz
  return 'digraph G {}' unless @tree

  lines = ['digraph G {']
  @tree.to_graphviz(0, lines)
  lines << '}'
  lines.join("\n")
end
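The returned string is plain DOT, so it can be written to a file and rendered with the standard GraphViz tools (file names here are illustrative):
File.write('id3_tree.dot', id3.to_graphviz)
# Then, with GraphViz installed, render from the shell:
#   dot -Tpng id3_tree.dot -o id3_tree.png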
#to_h ⇒ Object
Return a nested Hash representation of the decision tree. This structure can easily be converted to JSON or other formats. Leaf nodes are represented by their category value, while internal nodes are hashes keyed by attribute value.
# File 'lib/ai4r/classifiers/id3.rb', line 170

def to_h
  @tree&.to_h
end
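Because the result is a plain nested Hash, serializing it is straightforward:
require 'json'
puts JSON.pretty_generate(id3.to_h)
# Leaves serialize as category values ('Y'/'N'); internal nodes as
# hashes keyed by attribute value, per the description above.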