Module: TorchText::Data::Metrics

Defined in:
lib/torchtext/data/metrics.rb

Class Method Summary

Class Method Details

.bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: [0.25] * 4) ⇒ Object



5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# File 'lib/torchtext/data/metrics.rb', line 5

# Computes the corpus-level BLEU score of candidate sequences against their
# reference sequences.
#
# Modified (clipped) n-gram precisions for orders 1..max_n are accumulated
# over the whole corpus, combined as a weighted geometric mean, and scaled by
# a brevity penalty. For the brevity penalty, each candidate is paired with
# the reference whose length is closest to its own (the first such reference
# on ties).
#
# NOTE(review): assumes +compute_ngram_counter(seq, max_n)+ returns a Hash
# mapping each n-gram (an object whose +length+ is its order, 1..max_n) to
# its count in +seq+ -- inferred from the +ngram.length - 1+ indexing below;
# confirm against its definition.
#
# @param candidate_corpus [Array] candidate sequences (presumably token arrays)
# @param references_corpus [Array<Array>] for each candidate, an Array of one
#   or more reference sequences
# @param max_n [Integer] highest n-gram order to score
# @param weights [Array<Numeric>, nil] one weight per n-gram order, applied to
#   the log precisions; must have exactly +max_n+ entries. When omitted,
#   uniform weights of 1.0 / max_n are used (i.e. [0.25] * 4 for max_n = 4).
# @return [Float] the BLEU score; 0.0 when some n-gram order has no clipped match
# @raise [RuntimeError] if +weights+ does not have +max_n+ entries, if the two
#   corpora differ in length, or if a candidate has no references
def bleu_score(candidate_corpus, references_corpus, max_n: 4, weights: nil)
  weights ||= Array.new(max_n, 1.0 / max_n)
  unless max_n == weights.length
    raise "Length of the \"weights\" list has to be equal to max_n"
  end
  unless candidate_corpus.length == references_corpus.length
    raise "The length of candidate and reference corpus should be the same"
  end

  # Per-order tallies over the whole corpus: index k holds (k + 1)-gram counts.
  clipped_counts = Torch.zeros(max_n)
  total_counts = Torch.zeros(max_n)
  weights = Torch.tensor(weights)

  candidate_len = 0.0
  refs_len = 0.0

  candidate_corpus.zip(references_corpus) do |candidate, refs|
    # Without this check an empty list fails later with an opaque nil error.
    raise "Each candidate must have at least one reference" if refs.empty?

    candidate_len += candidate.length

    # Effective reference length: the reference closest in length to the
    # candidate (min_by keeps the first one on ties).
    refs_len += refs.map { |ref| ref.length.to_f }.min_by { |len| (candidate.length - len).abs }

    # For every n-gram, the highest count it reaches in any single reference.
    reference_counter = refs
      .map { |ref| compute_ngram_counter(ref, max_n) }
      .reduce { |merged, counter| merged.merge(counter) { |_, a, b| [a, b].max } }

    compute_ngram_counter(candidate, max_n).each do |ngram, count|
      order = ngram.length - 1
      total_counts[order] += count
      # Clip to the maximum reference count; n-grams that appear in no
      # reference contribute nothing to the clipped tally.
      clipped_counts[order] += [count, reference_counter[ngram]].min if reference_counter.key?(ngram)
    end
  end

  # A zero clipped count for any order zeroes the geometric mean; short-circuit
  # instead of taking log(0).
  return 0.0 if clipped_counts.to_a.min == 0

  precisions = clipped_counts / total_counts
  score = Torch.exp((weights * Torch.log(precisions)).sum)
  brevity_penalty = Math.exp([1 - refs_len / candidate_len, 0].min)

  brevity_penalty * score.item
end