Class: DecisionTree::ID3Tree
- Inherits: Object
- Ancestor chain: Object → DecisionTree::ID3Tree
- Defined in:
- lib/decisiontree/id3_tree.rb
Defined Under Namespace
Classes: Node
Instance Method Summary collapse
- #graph(filename) ⇒ Object
-
#id3_continuous(data, attributes, attribute) ⇒ Object
ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds).
-
#id3_discrete(data, attributes, attribute) ⇒ Object
ID3 for discrete label cases.
-
#initialize(attributes, data, default, type) ⇒ ID3Tree
constructor
A new instance of ID3Tree.
- #predict(test) ⇒ Object
- #train(data = @data, attributes = @attributes, default = @default) ⇒ Object
Constructor Details
#initialize(attributes, data, default, type) ⇒ ID3Tree
Returns a new instance of ID3Tree.
34 35 36 37 |
# File 'lib/decisiontree/id3_tree.rb', line 34

# Build an untrained ID3 tree.
#
# attributes - ordered list of attribute names matching each row's columns
# data       - training rows; each row's last element is its class label
# default    - classification returned when no training examples remain
# type       - :discrete or :continuous (selects the fitness algorithm)
def initialize(attributes, data, default, type)
  @type = type
  @used = {}   # attribute => thresholds already consumed by id3_continuous
  @tree = {}   # populated by #train
  @data = data
  @attributes = attributes
  @default = default
end
Instance Method Details
#graph(filename) ⇒ Object
102 103 104 105 |
# File 'lib/decisiontree/id3_tree.rb', line 102

# Render the trained tree to "<filename>.png" via DotGraphPrinter.
#
# filename - output path without extension; ".png" is appended
def graph(filename)
  dgp = DotGraphPrinter.new(build_tree)
  # BUG FIX: the interpolation was mangled to the literal string "#(unknown)",
  # which ignored the caller-supplied filename and wrote every graph to the
  # same file. Restore the intended "#{filename}.png" interpolation.
  dgp.write_to_file("#{filename}.png", "png")
end
#id3_continuous(data, attributes, attribute) ⇒ Object
ID3 for binary classification of continuous variables (e.g. healthy / sick based on temperature thresholds)
75 76 77 78 79 80 81 82 83 84 85 86 87 |
# File 'lib/decisiontree/id3_tree.rb', line 75 def id3_continuous(data, attributes, attribute) values, thresholds = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort, [] values.each_index { |i| thresholds.push((values[i]+(values[i+1].nil? ? values[i] : values[i+1])).to_f / 2) } thresholds -= @used[attribute] if @used.has_key? attribute gain = thresholds.collect { |threshold| sp = data.partition { |d| d[attributes.index(attribute)] > threshold } pos = (sp[0].size).to_f / data.size neg = (sp[1].size).to_f / data.size [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold] }.max { |a,b| a[0] <=> b[0] } end |
#id3_discrete(data, attributes, attribute) ⇒ Object
ID3 for discrete label cases
90 91 92 93 94 95 96 |
# File 'lib/decisiontree/id3_tree.rb', line 90

# ID3 for discrete label cases: information gain of splitting `data` on
# `attribute`. Returns [gain, column index of the attribute].
def id3_discrete(data, attributes, attribute)
  col = attributes.index(attribute)
  # Expected entropy remaining after partitioning rows by each distinct
  # attribute value, weighted by partition size.
  values = data.map { |row| row[col] }.uniq.sort
  remainder = values.inject(0) do |sum, value|
    subset = data.select { |row| row[col] == value }
    sum + (subset.size.to_f / data.size) * subset.classification.entropy
  end
  [data.classification.entropy - remainder, col]
end
#predict(test) ⇒ Object
98 99 100 |
# File 'lib/decisiontree/id3_tree.rb', line 98

# Classify one example by walking the trained tree, dispatching on the
# tree type chosen at construction time.
def predict(test)
  if @type == :discrete
    descend_discrete(@tree, test)
  else
    descend_continuous(@tree, test)
  end
end
#train(data = @data, attributes = @attributes, default = @default) ⇒ Object
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
# File 'lib/decisiontree/id3_tree.rb', line 39 def train(data=@data, attributes=@attributes, default=@default) # Choose a fitness algorithm case @type when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)} when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)} end return default if data.empty? # return classification if all examples have the same classification return data.first.last if data.classification.uniq.size == 1 # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute) performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) } max = performance.max { |a,b| a[0] <=> b[0] } best = Node.new(attributes[performance.index(max)], max[1], max[0]) @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold] tree, l = {best => {}}, ['>=', '<'] case @type when :continuous data.partition { |d| d[attributes.index(best.attribute)] > best.threshold }.each_with_index { |examples, i| tree[best][String.new(l[i])] = train(examples, attributes, (data.classification.mode rescue 0), &fitness) } when :discrete values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } } partitions.each_with_index { |examples, i| tree[best][values[i]] = train(examples, attributes-[values[i]], (data.classification.mode rescue 0), &fitness) } end @tree = tree end |