Class: TensorFlow::Data::Dataset

Inherits:
Object
Includes:
Enumerable
Defined in:
lib/tensorflow/data/dataset.rb

Direct Known Subclasses

BatchDataset, ShuffleDataset, TensorSliceDataset

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(variant_tensor) ⇒ Dataset

Returns a new instance of Dataset.



# File 'lib/tensorflow/data/dataset.rb', line 9

# Stores the variant tensor backing this dataset. #each hands it to
# RawOps.make_iterator as the dataset handle, and #to_ptr delegates to it.
#
# @param variant_tensor [Object] dataset handle (presumably a variant-dtype
#   tensor — TODO confirm)
def initialize(variant_tensor); @variant_tensor = variant_tensor; end

Instance Attribute Details

#output_shapes ⇒ Object (readonly)

TODO: remove this reader.



# File 'lib/tensorflow/data/dataset.rb', line 7

# Reader for the shapes value that #each forwards to RawOps as
# `output_shapes:`. Dataset#initialize does not assign it, so it stays nil
# unless something else sets @output_shapes (presumably the subclasses —
# TODO confirm). Marked "TODO remove" upstream.
#
# @return [Object, nil]
def output_shapes; @output_shapes; end

#output_types ⇒ Object (readonly)

TODO: remove this reader.



# File 'lib/tensorflow/data/dataset.rb', line 7

# Reader for the types value that #each forwards to RawOps as
# `output_types:`. Dataset#initialize does not assign it, so it stays nil
# unless something else sets @output_types (presumably the subclasses —
# TODO confirm). Marked "TODO remove" upstream.
#
# @return [Object, nil]
def output_types; @output_types; end

Class Method Details

.from_tensor_slices(tensors) ⇒ Object



# File 'lib/tensorflow/data/dataset.rb', line 21

# Convenience constructor that wraps +tensors+ in a TensorSliceDataset.
#
# @param tensors [Object] passed unchanged to TensorSliceDataset.new
# @return [TensorSliceDataset]
def self.from_tensor_slices(tensors); TensorSliceDataset.new(tensors); end

Instance Method Details

#batch(batch_size, drop_remainder: false) ⇒ Object



# File 'lib/tensorflow/data/dataset.rb', line 13

# Wraps this dataset in a BatchDataset (presumably grouping consecutive
# elements into batches of +batch_size+ — TODO confirm in BatchDataset).
#
# @param batch_size [Object] forwarded to BatchDataset.new
# @param drop_remainder [Boolean] forwarded positionally to BatchDataset.new
# @return [BatchDataset]
def batch(batch_size, drop_remainder: false); BatchDataset.new(self, batch_size, drop_remainder); end

#each ⇒ Object



# File 'lib/tensorflow/data/dataset.rb', line 29

# Iterates over the dataset, yielding one element per step.
#
# Creates an anonymous iterator, binds it to this dataset's variant tensor,
# and pulls elements with RawOps.iterator_get_next_sync until an Error with
# the message "End of sequence" is raised. The iterator is always released
# in the ensure clause, even if the block raises or breaks out early.
#
# Without a block, returns an Enumerator, per the Enumerable convention.
#
# @yieldparam values [Object] one element as returned by
#   RawOps.iterator_get_next_sync (presumably a tensor or array of tensors —
#   TODO confirm)
# @return [self, Enumerator] self after full iteration; an Enumerator when
#   no block is given
# @raise [Error] any Error whose message is not "End of sequence"
def each
  # Without this guard, `yield` would raise LocalJumpError, and a native
  # iterator would be allocated only to be deleted again.
  return enum_for(:each) unless block_given?

  iterator, deleter = RawOps.anonymous_iterator_v2(output_types: @output_types, output_shapes: @output_shapes)
  RawOps.make_iterator(dataset: @variant_tensor, iterator: iterator)
  begin
    loop do
      yield RawOps.iterator_get_next_sync(iterator: iterator, output_types: @output_types, output_shapes: @output_shapes)
    end
  rescue Error => e
    # Exhaustion is reported as an Error with this exact message; treat it as
    # normal termination and re-raise anything else unchanged.
    raise unless e.message == "End of sequence"
  end
  self
ensure
  # iterator is nil if anonymous_iterator_v2 raised or no block was given.
  RawOps.delete_iterator(handle: iterator, deleter: deleter) if iterator
end

#shuffle(buffer_size) ⇒ Object



# File 'lib/tensorflow/data/dataset.rb', line 17

# Wraps this dataset in a ShuffleDataset (presumably shuffling elements
# through a buffer of +buffer_size+ — TODO confirm in ShuffleDataset).
#
# @param buffer_size [Object] forwarded to ShuffleDataset.new
# @return [ShuffleDataset]
def shuffle(buffer_size); ShuffleDataset.new(self, buffer_size); end

#to_ptr ⇒ Object



# File 'lib/tensorflow/data/dataset.rb', line 25

# Delegates to the backing variant tensor's #to_ptr (presumably so a Dataset
# can be passed to RawOps wherever a tensor pointer is expected — TODO confirm).
#
# @return [Object] whatever @variant_tensor.to_ptr returns
def to_ptr; @variant_tensor.to_ptr; end