Class: NerRuby::SlidingWindow

Inherits: Object
Defined in:
lib/ner_ruby/sliding_window.rb

Constant Summary

DEFAULT_MAX_LENGTH = 512
DEFAULT_STRIDE = 128

Instance Method Summary

Constructor Details

#initialize(max_length: DEFAULT_MAX_LENGTH, stride: DEFAULT_STRIDE) ⇒ SlidingWindow

Returns a new instance of SlidingWindow.



# File 'lib/ner_ruby/sliding_window.rb', line 8

# Configure the window geometry used by #split.
#
# @param max_length [Integer] largest number of tokens placed in one window
# @param stride [Integer] number of tokens that consecutive windows share;
#   #split advances each new window by +max_length - stride+ positions
def initialize(max_length: DEFAULT_MAX_LENGTH, stride: DEFAULT_STRIDE)
  @stride = stride
  @max_length = max_length
end

Instance Method Details

#merge_entities(window_results) ⇒ Object

Merge entities from overlapping windows, preferring higher scores



# File 'lib/ner_ruby/sliding_window.rb', line 35

# Merge entities predicted on overlapping windows into one list in which no
# two entities overlap, preferring higher scores.
#
# Candidates are considered from highest to lowest score, and a candidate is
# kept only if it overlaps no entity already kept. On equal scores the
# entity seen first in the input wins. This makes the result independent of
# window order and guarantees an overlap-free output. The previous
# "replace the first overlapping match" loop did not: a higher-scoring entity
# could replace one match while still overlapping another kept entity.
#
# @param window_results [Array<Array>] one entity list per window; each
#   entity must respond to +score+ and +start_offset+ and be accepted by
#   +overlaps?+
# @return [Array] the kept entities ordered by +start_offset+ (nil sorts as 0)
def merge_entities(window_results)
  # The flattened input index breaks score ties so the earliest-seen entity
  # wins, matching the old strict ">" comparison.
  ranked = window_results.flat_map(&:to_a)
                         .each_with_index
                         .sort_by { |entity, index| [-entity.score, index] }

  kept = []
  ranked.each do |entity, _index|
    kept << entity if kept.none? { |other| overlaps?(other, entity) }
  end

  kept.sort_by { |e| e.start_offset || 0 }
end

#split(tokens, ids) ⇒ Object

Split tokens into overlapping windows



# File 'lib/ner_ruby/sliding_window.rb', line 14

# Split a token sequence, and the id sequence aligned with it, into
# overlapping windows of at most +@max_length+ items.
#
# Each window starts +@max_length - @stride+ positions after the previous
# one, so full windows share +@stride+ items. The last window always ends at
# the final token. Input that already fits in one window is returned as a
# single window at offset 0.
#
# @param tokens [Array] token sequence to split
# @param ids [Array] ids aligned index-for-index with +tokens+
# @return [Array<Hash>] windows of the form +{ tokens:, ids:, offset: }+,
#   where +offset+ is the index in +tokens+ of the window's first item
# @raise [ArgumentError] if splitting is needed but the configuration cannot
#   make progress or would skip tokens (+max_length < 1+, +stride < 0+, or
#   +stride >= max_length+). Without this check the loop never terminated.
def split(tokens, ids)
  return [{ tokens: tokens, ids: ids, offset: 0 }] if tokens.length <= @max_length

  if @max_length < 1 || @stride.negative? || @stride >= @max_length
    raise ArgumentError,
          "stride must be in 0...max_length (got max_length=#{@max_length}, stride=#{@stride})"
  end

  step = @max_length - @stride
  windows = []
  start = 0

  loop do
    window_end = [start + @max_length, tokens.length].min
    windows << {
      tokens: tokens[start...window_end],
      ids: ids[start...window_end],
      offset: start
    }
    break if window_end >= tokens.length

    start += step
  end

  windows
end