Class: RedditPostClassifierBot::RedditTrainer

Inherits:
Object
  • Object
show all
Defined in:
lib/RedditPostClassifierBot/reddit_trainer.rb

Defined Under Namespace

Classes: Post

Constant Summary collapse

# Base URL of the site the trainer fetches listings from.
REDDIT_URL =
"https://www.reddit.com"
# Maps each classification label to the listing path used to fetch posts for
# that label. Every label is a Symbol so that the keys returned by
# .trained_on and the labels handed to the classifier are of one type.
CLASSES =
{
  hot: "/",
  new: "/new/",
  rising: "/rising/",
  controversial: "/controversial/",
  # Request the hour window explicitly, like the other top_* entries, rather
  # than relying on whatever window the bare /top/ listing defaults to.
  top_hour: "/top/?sort=top&t=hour",
  top_day: "/top/?sort=top&t=day",
  top_week: "/top/?sort=top&t=week",
  top_month: "/top/?sort=top&t=month",
  top_year: "/top/?sort=top&t=year"
}

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(trials = 10, per_page = 200, debug = true) ⇒ RedditTrainer

Returns a new instance of RedditTrainer.



22
23
24
25
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 22

# Sets up a trainer with an empty post list and a zeroed page counter.
#
# @param trials [Integer] stored as @max_trials; #train compares its page
#   counter against this value to decide whether to fetch another page
# @param per_page [Integer] stored as @per_page
# @param debug [Boolean] stored as @debug (presumably toggles logging; the
#   consumer is not visible here)
def initialize(trials = 10, per_page = 200, debug = true)
  @max_trials = trials
  @per_page = per_page
  @debug = debug

  @posts = []
  @trials_done = 0
end

Instance Attribute Details

#classifications ⇒ Object (readonly)

Returns the value of attribute classifications.



16
17
18
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 16

# Reader for @classifications. In the visible code only #classify assigns
# this variable, so it is nil until a post has been classified.
#
# @return [Object, nil] the value last stored in @classifications
def classifications
  @classifications
end

Class Method Details

.trained_on ⇒ Object



18
19
20
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 18

# Lists the classification labels the trainer knows about.
#
# @return [Array] the keys of CLASSES, in definition order
def self.trained_on
  # Hash#keys is lowercase; the previous `CLASSES.KEYS` raised NoMethodError.
  CLASSES.keys
end

Instance Method Details

#classify(subreddit, title, post) ⇒ Object



46
47
48
49
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 46

# Runs one post through the classifier and reports the winning class.
#
# The full result object is kept in @classifications so it can be read back
# afterwards through #classifications.
#
# @param subreddit [Object] forwarded unchanged to the classifier
# @param title [Object] forwarded unchanged to the classifier
# @param post [Object] forwarded unchanged to the classifier
# @return [Object] whatever the classifier result reports as its max_class
def classify(subreddit, title, post)
  outcome = nbayes.classify(subreddit, title, post)
  @classifications = outcome
  outcome.max_class
end

#dump ⇒ Object



51
52
53
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 51

# Asks the classifier to dump itself (presumably persisting the trained
# model; the classifier's implementation is not visible here).
#
# @return [self] the trainer, so calls can be chained
def dump
  nbayes.dump
  self
end

#fetch_and_classify(path = ) ⇒ Object



60
61
62
63
64
65
66
67
68
69
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 60

# Fetches one listing page, classifies every post on it and groups the
# results by classification.
#
# @param path [String] listing path handed to #reddit. Defaults to the "hot"
#   listing; the previous default, CLASSES[:front], named a key that CLASSES
#   does not define and therefore evaluated to nil.
# @return [Hash] classification => array of [post URI string, classification]
#   pairs
def fetch_and_classify(path = CLASSES[:hot])
  posts = reddit(path)["data"]["children"]
  log "Classifying #{posts.size} posts"

  # Keyed by URI string, so a URI that appears twice keeps its last result.
  by_uri = posts.each_with_object({}) do |raw, acc|
    post = Post.new(raw)
    classification = classify(post.subreddit, post.title, post.body)
    acc[uri_with_base(post.path).to_s] = classification
  end

  by_uri.group_by { |_uri, classification| classification }
end

#inspect ⇒ Object



71
72
73
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 71

# Compact debugging representation: class name, object id in hex, the three
# configuration values and the number of collected posts (the posts
# themselves are left out).
#
# @return [String]
def inspect
  summary = [
    "@max_trials=#{@max_trials.inspect}",
    "@per_page=#{@per_page.inspect}",
    "@debug=#{@debug.inspect}",
    "@posts.size=#{@posts.size}"
  ].join(" ")

  "<#{self.class}:#{object_id.to_s(16)} #{summary}>"
end

#load ⇒ Object



55
56
57
58
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 55

# Asks the classifier to load itself; when that reports a falsy result, the
# trainer trains from scratch and, if training returns a truthy value, dumps
# the result.
#
# @return [self] the trainer, so calls can be chained
def load
  return self if nbayes.load

  train && dump
  self
end

#nbayes ⇒ Object



27
28
29
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 27

# Lazily builds the classifier and caches it in @nbayes, so every call on
# the same trainer returns the same classifier instance.
#
# @return [RedditPostClassifierBot::NBayesClassifier]
def nbayes
  return @nbayes if @nbayes

  @nbayes = RedditPostClassifierBot::NBayesClassifier.new
end

#train(classes = CLASSES) ⇒ Object



31
32
33
34
35
36
37
38
39
40
41
42
43
44
# File 'lib/RedditPostClassifierBot/reddit_trainer.rb', line 31

# Trains the classifier on one page of posts for every classification, then
# hands over to #recurse_to_next_page while the page budget allows.
#
# @param classes [Hash] classification label => listing path; defaults to
#   CLASSES
# @return [self] the trainer, so calls can be chained
def train(classes = CLASSES)
  classes.each do |classification, path|
    log "training on #{classification} posts, page #{@trials_done} of #{@max_trials}"

    reddit(path)["data"]["children"].each do |child|
      post = Post.new(child)
      @posts << post
      nbayes.train(post.serialize, classification)
    end
  end

  @trials_done += 1
  # Pass the caller's class map along. The constant CLASSES was passed here
  # before, which silently dropped a custom map after the first page.
  # NOTE(review): with `<=` and a counter starting at 0 this looks like it
  # allows max_trials + 1 passes in total — confirm that is intended.
  recurse_to_next_page(classes, @posts.last) if @trials_done <= @max_trials
  self
end