Class: Wapiti::Model

Inherits:
Object
  • Object
show all
Defined in:
lib/wapiti/model.rb,
ext/wapiti/native.c

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(*args) ⇒ Object



634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
# File 'ext/wapiti/native.c', line 634

// Model#initialize([options]) { |options| ... }
//
// Accepts at most one argument: either a Hash (used to build a new Options
// instance) or an existing Options instance. With no argument a default
// Options instance is created. The options object is yielded to the block
// (if one is given) before it is attached to the model, so the block can
// still adjust it. If the options name a model file, the `load` method is
// called; finally the counters are reset via `reset_counters`.
//
// Raises ArgumentError for more than one argument, or for an argument that
// is neither a Hash nor an Options instance. Returns self.
static VALUE initialize_model(int argc, VALUE *argv, VALUE self) {
  VALUE options;

  if (argc > 1) {
    rb_raise(cArgumentError,
      "wrong number of arguments (%d for 0..1)", argc);
  }

  if (argc) {
    if (TYPE(argv[0]) == T_HASH) {
      options = rb_funcall(cOptions, rb_intern("new"), 1, argv[0]);
    } else {
      // Check the actual class hierarchy rather than comparing a prefix of
      // the class name: a name comparison lets unrelated classes whose name
      // merely starts with "Wapiti::Options" reach get_options() below, and
      // rejects genuine subclasses of Options.
      if (!RTEST(rb_obj_is_kind_of(argv[0], cOptions))) {
        rb_raise(cArgumentError, "argument must be a hash or an options instance");
      }
      options = argv[0];
    }
  } else {
    options = rb_funcall(cOptions, rb_intern("new"), 0);
  }

  // yield options if block_given?
  if (rb_block_given_p()) {
    rb_yield(options);
  }

  model_set_options(self, options);

  // Load a previous model if specified by options
  if (get_options(options)->model) {
    rb_funcall(self, rb_intern("load"), 0);
  }

  // initialize counters
  rb_funcall(self, rb_intern("reset_counters"), 0);

  return self;
}

Instance Attribute Details

#optionsObject (readonly)

#pathObject

Returns the value of attribute path.



20
21
22
# File 'lib/wapiti/model.rb', line 20

# Reader for the +@path+ instance variable. The native #load and #save
# methods read this same variable to decide which file to open.
def path
  @path
end

#sequence_countObject (readonly)

Returns the value of attribute sequence_count.



21
22
23
# File 'lib/wapiti/model.rb', line 21

# Reader for the +@sequence_count+ counter. #reset_counters sets it back
# to 0; #sequence_error_rate uses it as the denominator.
def sequence_count
  @sequence_count
end

#sequence_errorsObject (readonly)

Returns the value of attribute sequence_errors.



21
22
23
# File 'lib/wapiti/model.rb', line 21

# Reader for the +@sequence_errors+ counter. #reset_counters sets it back
# to 0; #sequence_error_rate uses it as the numerator.
def sequence_errors
  @sequence_errors
end

#token_countObject (readonly)

Returns the value of attribute token_count.



21
22
23
# File 'lib/wapiti/model.rb', line 21

# Reader for the +@token_count+ counter. #reset_counters sets it back
# to 0; #token_error_rate uses it as the denominator.
def token_count
  @token_count
end

#token_errorsObject (readonly)

Returns the value of attribute token_errors.



21
22
23
# File 'lib/wapiti/model.rb', line 21

# Reader for the +@token_errors+ counter. #reset_counters sets it back
# to 0; #token_error_rate uses it as the numerator.
def token_errors
  @token_errors
end

Class Method Details

.load(filename) ⇒ Object



13
14
15
16
17
# File 'lib/wapiti/model.rb', line 13

# Builds a fresh model, points its path at +filename+ and loads it.
# Returns whatever the instance-level #load returns.
def load(filename)
  new.tap { |model| model.path = filename }.load
end

.train(training_data, options = {}, &block) ⇒ Object



4
5
6
7
8
9
10
11
# File 'lib/wapiti/model.rb', line 4

# Builds a model from +options+ and trains it on +training_data+.
#
# The development data set is taken from the +:development_data+ key of
# +options+, falling back to +:data+; the extracted key is not forwarded
# to Options.new. The block, if given, is passed on to Options.new.
#
# Returns the result of the instance-level #train call.
def train(training_data, options = {}, &block)
  # Work on a shallow copy so the caller's hash keeps its keys (and so a
  # frozen hash can be passed in).
  options = options.dup

  development_data =
    options.delete(:development_data) ||
    options.delete(:data)

  config = Options.new(options, &block)
  new(config).train(training_data, development_data)
end

Instance Method Details

#check(input) ⇒ Object



56
57
58
59
60
# File 'lib/wapiti/model.rb', line 56

# Labels +input+ with the +check+ option switched on and reports the
# resulting counters: resets first, runs #label, then returns #stats.
def check(input)
  reset
  label(input, check: true)
  stats
end

#compactObject



696
697
698
699
# File 'ext/wapiti/native.c', line 696

// Model#compact: calls mdl_compact on the native model obtained from self
// and returns self. What compaction does is defined by mdl_compact, which
// is not shown here.
static VALUE model_compact(VALUE self) {
  mdl_compact(get_model(self));
  return self;
}

#label(data) ⇒ Object



1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
# File 'ext/wapiti/native.c', line 1088

# Labels +input+ via the native labeller.
#
# When +opts+ is given, the named options are overridden for the duration
# of this call only: their current values are captured first and written
# back in the +ensure+ clause, even if labelling raises.
#
# A Dataset input is converted to an array (tagged according to the
# +check+ option). The raw native output is returned as-is when +nbest+
# is greater than 1 or +skip_tokens+ is set; otherwise it is parsed into
# a tagged Dataset.
def label(input, opts = nil, &block)
  overriding = !opts.nil?

  if overriding
    original_options = options.attributes(opts.keys)
    options.update!(opts)
  end

  if input.is_a?(Dataset)
    input = input.to_a(tagged: options.check)
  end

  # Forwarding a nil block is the same as passing no block at all.
  output = native_label(input, &block)

  if options.nbest > 1 || options.skip_tokens
    output
  else
    Dataset.parse(output, tagged: true)
  end
ensure
  options.update(original_options) unless original_options.nil?
end

#labelsObject

Returns a sorted list of all labels in the Model’s label database.



907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
# File 'ext/wapiti/native.c', line 907

// Model#labels: collects the string for every label id in 0...nlbl from
// the reader's label database into a Ruby Array, sorts that array in
// place with Array#sort!, and returns it.
static VALUE model_labels(VALUE self) {
  mdl_t *mdl = get_model(self);
  qrk_t *label_db = mdl->reader->lbl;
  const uint32_t count = mdl->nlbl;

  // Pre-size the array for the number of labels.
  VALUE result = rb_ary_new2(count);

  unsigned int id = 0;
  while (id < count) {
    rb_ary_push(result, rb_str_new2(qrk_id2str(label_db, id)));
    ++id;
  }

  rb_funcall(result, rb_intern("sort!"), 0);

  return result;
}

#load(*args) ⇒ Object



735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
# File 'ext/wapiti/native.c', line 735

// Model#load([path]): reads model data from a file into the native model.
//
// With an argument (which must be a String) the argument is first stored
// in @path; the file named by @path is then opened for reading, handed to
// mdl_load and closed. Raises ArgumentError for more than one argument and
// calls fatal() when no path is available. Returns self.
//
// NOTE(review): if mdl_load can raise a Ruby exception, the FILE handle
// would not be closed -- confirm how mdl_load reports errors.
static VALUE model_load(int argc, VALUE *argv, VALUE self) {
  if (argc > 1) {
    rb_raise(cArgumentError,
      "wrong number of arguments (%d for 0..1)", argc);
  }

  mdl_t *mdl = get_model(self);

  // an explicit argument replaces the stored @path before it is read back
  if (argc != 0) {
    Check_Type(argv[0], T_STRING);
    rb_ivar_set(self, rb_intern("@path"), argv[0]);
  }

  VALUE source = rb_ivar_get(self, rb_intern("@path"));
  if (NIL_P(source)) {
    fatal("failed to load model: no path given");
  }

  // read the model from the file
  FILE *in = ufopen(source, "r");
  mdl_load(mdl, in);
  fclose(in);

  return self;
}

#nftrObject Also known as: features



684
685
686
# File 'ext/wapiti/native.c', line 684

// Model#nftr (also exposed as #features): returns the native model's nftr
// field as a Ruby Integer.
// NOTE(review): INT2FIX does no range checking -- presumably nftr always
// fits in a Fixnum; confirm against the field's declared width.
static VALUE model_nftr(VALUE self) {
  return INT2FIX(get_model(self)->nftr);
}

#nlblObject

Returns the number of labels in the model (native accessor).



676
677
678
# File 'ext/wapiti/native.c', line 676

// Model#nlbl: returns the native model's nlbl field as a Ruby Integer.
// NOTE(review): INT2FIX does no range checking -- presumably nlbl always
// fits in a Fixnum; confirm against the field's declared width.
static VALUE model_nlbl(VALUE self) {
  return INT2FIX(get_model(self)->nlbl);
}

#nobsObject Also known as: observations



680
681
682
# File 'ext/wapiti/native.c', line 680

// Model#nobs (also exposed as #observations): returns the native model's
// nobs field as a Ruby Integer.
// NOTE(review): INT2FIX does no range checking -- presumably nobs always
// fits in a Fixnum; confirm against the field's declared width.
static VALUE model_nobs(VALUE self) {
  return INT2FIX(get_model(self)->nobs);
}

#patternObject



23
24
25
# File 'lib/wapiti/model.rb', line 23

# Convenience reader: returns the +pattern+ value held by the model's
# options object.
def pattern
  options.pattern
end

#pattern=(filename) ⇒ Object



27
28
29
# File 'lib/wapiti/model.rb', line 27

# Convenience writer: stores +filename+ as the +pattern+ value on the
# model's options object.
def pattern=(filename)
  options.pattern = filename
end

#reset_countersObject Also known as: reset



94
95
96
97
98
99
100
# File 'lib/wapiti/model.rb', line 94

# Zeroes the token and sequence counters (counts and errors alike).
# Returns self so the call can be chained.
def reset_counters
  @token_count = @token_errors = 0
  @sequence_count = @sequence_errors = 0
  self
end

#save(*args) ⇒ Object

Saves the model to the file at the Model’s path. If an argument is passed in, it is stored as the Model’s path first; otherwise the current path is used.



707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
# File 'ext/wapiti/native.c', line 707

// Model#save([path]): writes the native model to a file.
//
// With an argument (which must be a String) the argument is first stored
// in @path; the file named by @path is then opened for writing, handed to
// mdl_save and closed. Raises ArgumentError for more than one argument and
// calls fatal() when no path is available. Returns self.
//
// NOTE(review): if mdl_save can raise a Ruby exception, the FILE handle
// would not be closed -- confirm how mdl_save reports errors.
static VALUE model_save(int argc, VALUE *argv, VALUE self) {
  if (argc > 1) {
    rb_raise(cArgumentError,
      "wrong number of arguments (%d for 0..1)", argc);
  }

  mdl_t *mdl = get_model(self);

  // an explicit argument replaces the stored @path before it is read back
  if (argc != 0) {
    Check_Type(argv[0], T_STRING);
    rb_ivar_set(self, rb_intern("@path"), argv[0]);
  }

  VALUE target = rb_ivar_get(self, rb_intern("@path"));
  if (NIL_P(target)) {
    fatal("failed to save model: no path given");
  }

  // write the model to the file
  FILE *out = ufopen(target, "w");
  mdl_save(mdl, out);
  fclose(out);

  return self;
}

#sequence_error_rateObject



109
110
111
112
# File 'lib/wapiti/model.rb', line 109

# Percentage of sequences counted as errors: sequence_errors divided by
# sequence_count, times 100. Returns the Integer 0 when no sequence
# errors have been recorded (which also avoids dividing by a zero count).
def sequence_error_rate
  errors = sequence_errors
  return 0 if errors.zero?

  errors / sequence_count.to_f * 100.0
end

#statisticsObject Also known as: stats



77
78
79
80
81
82
83
84
85
86
87
88
89
90
# File 'lib/wapiti/model.rb', line 77

# Snapshot of the current counters as a nested Hash with a +:token+ and a
# +:sequence+ entry, each holding +:count+, +:errors+ and +:rate+.
def statistics
  token = {
    count: token_count,
    errors: token_errors,
    rate: token_error_rate
  }

  sequence = {
    count: sequence_count,
    errors: sequence_errors,
    rate: sequence_error_rate
  }

  { token: token, sequence: sequence }
end

#syncObject

Calls the native mdl_sync routine on the underlying model and returns the Model itself (instance method).



691
692
693
694
# File 'ext/wapiti/native.c', line 691

// Model#sync: calls mdl_sync on the native model obtained from self and
// returns self. What synchronisation does is defined by mdl_sync, which is
// not shown here.
static VALUE model_sync(VALUE self) {
  mdl_sync(get_model(self));
  return self;
}

#token_error_rateObject



104
105
106
107
# File 'lib/wapiti/model.rb', line 104

# Percentage of tokens counted as errors: token_errors divided by
# token_count, times 100. Returns the Integer 0 when no token errors have
# been recorded (which also avoids dividing by a zero count).
def token_error_rate
  errors = token_errors
  return 0 if errors.zero?

  errors / token_count.to_f * 100.0
end

#train(train, devel) ⇒ Object



835
836
837
838
839
840
841
842
843
844
845
846
# File 'ext/wapiti/native.c', line 835

# Trains the model on +tdat+, with +ddat+ as optional development data.
#
# When +opts+ is given it is applied to the model's options first (and,
# unlike in #label, not restored afterwards). Dataset arguments are
# converted to tagged arrays; anything else is handed to the native
# trainer untouched. Returns the native trainer's result.
def train(tdat, ddat = nil, opts = nil, &block)
  unless opts.nil?
    options.update!(opts)
  end

  tdat, ddat = [tdat, ddat].map do |data|
    data.is_a?(Dataset) ? data.to_a(tagged: true) : data
  end

  # Forwarding a nil block is the same as passing no block at all.
  native_train(tdat, ddat, &block)
end