Method: Rumale::EvaluationMeasure::ROCAUC#roc_curve

Defined in:
lib/rumale/evaluation_measure/roc_auc.rb

#roc_curve(y_true, y_score, pos_label = nil) ⇒ Array

Calculate the receiver operating characteristic (ROC) curve.

Parameters:

  • y_true (Numo::Int32)

    (shape: [n_samples]) Ground truth labels. They must be binary unless pos_label is specified.

  • y_score (Numo::DFloat)

    (shape: [n_samples]) Predicted class probabilities or confidence scores.

  • pos_label (Integer) (defaults to: nil)

    Label to treat as positive when binarizing the given labels. If nil is given, y_true must contain exactly two distinct labels, and the method treats the maximum label value as the positive label.

Returns:

  • (Array)

    An array [fpr, tpr, thresholds], where fpr (Numo::DFloat) holds the false positive rates, tpr (Numo::DFloat) holds the true positive rates, and thresholds (Numo::DFloat) holds the thresholds on the decision function used to calculate fpr and tpr.

Raises:

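  • (ArgumentError) — if y_true or y_score is not a one-dimensional array, if pos_label is nil and y_true does not contain exactly two distinct labels, or if pos_label does not occur in y_true.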


61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# File 'lib/rumale/evaluation_measure/roc_auc.rb', line 61

# Calculate the receiver operating characteristic (ROC) curve.
#
# @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
#   They must be binary unless +pos_label+ is given.
# @param y_score [Numo::DFloat] (shape: [n_samples]) Predicted class probabilities or confidence scores.
# @param pos_label [Integer] Label treated as positive when binarizing +y_true+.
#   If nil, +y_true+ must contain exactly two distinct labels.
# @return [Array] [fpr, tpr, thresholds] where
#   fpr (Numo::DFloat) are false positive rates,
#   tpr (Numo::DFloat) are true positive rates, and
#   thresholds (Numo::DFloat) are the decision-function thresholds used to calculate fpr and tpr.
# @raise [ArgumentError] if +y_true+ or +y_score+ is not 1-D, if their lengths differ,
#   if +pos_label+ is nil and +y_true+ does not have exactly two distinct labels,
#   or if +pos_label+ does not occur in +y_true+.
def roc_curve(y_true, y_score, pos_label = nil)
  check_params_type(Numo::NArray, y_true: y_true, y_score: y_score)
  # ndim is exact; the previous shape[1].nil? test also let 0-d arrays through.
  raise ArgumentError, 'Expect y_true to be 1-D arrray.' unless y_true.ndim == 1
  raise ArgumentError, 'Expect y_score to be 1-D arrray.' unless y_score.ndim == 1
  # Labels and scores are paired element-wise, so their lengths must agree.
  unless y_true.size == y_score.size
    raise ArgumentError, 'Expect y_true and y_score to have the same number of elements.'
  end

  labels = y_true.to_a.uniq
  if pos_label.nil?
    raise ArgumentError, 'y_true must be binary labels or pos_label must be specified if y_true is multi-label' unless labels.size == 2
  else
    raise ArgumentError, 'y_true must have elements whose values are pos_label.' unless labels.include?(pos_label)
  end

  false_pos, true_pos, thresholds = binary_roc_curve(y_true, y_score, pos_label)

  # Make the curve start at the origin: prepend a (0, 0) point whose threshold
  # is one above the first threshold returned by binary_roc_curve.
  if true_pos.size.zero? || false_pos[0] != 0 || true_pos[0] != 0
    true_pos = true_pos.insert(0, 0)
    false_pos = false_pos.insert(0, 0)
    thresholds = thresholds.insert(0, thresholds[0] + 1)
  end

  # Normalize by the last entries so each rate ends at 1.0.
  # NOTE(review): if y_true has no positives (or no negatives) the divisor is 0
  # and the corresponding rates become NaN/Inf — confirm whether callers expect this.
  tpr = true_pos / true_pos[-1].to_f
  fpr = false_pos / false_pos[-1].to_f

  [fpr, tpr, thresholds]
end