426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
|
# File 'lib/numo/narray/extra.rb', line 426
# Joins a sequence of arrays end to end along an existing axis.
#
# Each element of +arrays+ may be an NArray, a Numeric scalar (wrapped as
# <tt>klass[x]</tt>), or a Ruby Array (converted with <tt>klass.cast</tt>).
# Inputs of lower rank are treated as if their shape were left-padded with
# length-1 dimensions up to the highest rank among the inputs.
#
# @param arrays [Array] arrays, scalars, or Ruby Arrays to join
# @param axis [Integer] axis to join along; negative values count back from
#   the last axis of the highest-rank input
# @return [NArray] a new +klass+ array containing the inputs laid out
#   consecutively along +axis+
# @raise [TypeError] if an element is not an NArray, Numeric, or Array
# @raise [ArgumentError] if +axis+ falls outside the output rank (this also
#   covers an empty +arrays+, whose rank is 0)
# @raise [ShapeError] if the inputs disagree on any dimension other than +axis+
def concatenate(arrays, axis: 0)
  # Called on the abstract NArray: derive a concrete result class from the
  # inputs. Called on a concrete subclass: use that subclass directly.
  klass = self == NArray ? NArray.array_type(arrays) : self

  # Normalize every element to an NArray instance.
  items = arrays.map do |elem|
    case elem
    when NArray  then elem
    when Numeric then klass[elem]
    when Array   then klass.cast(elem)
    else
      raise TypeError,"not Numo::NArray: #{elem.inspect[0..48]}"
    end
  end

  # Output rank is the highest input rank (0 when there are no inputs).
  rank = items.map(&:ndim).max || 0
  axis += rank if axis < 0
  raise ArgumentError,"axis is out of range" unless axis >= 0 && axis < rank

  # View every input as rank-dimensional by prepending length-1 dimensions.
  padded = items.map do |a|
    dims = a.shape
    Array.new(rank - dims.size, 1) + dims
  end

  # All inputs must agree on every dimension except the joined one.
  others = padded.map { |dims| dims[0, axis] + dims[(axis + 1)..-1] }
  raise ShapeError,"shape mismatch" unless others.all? { |o| o == others.first }

  extent = padded.inject(0) { |total, dims| total + dims[axis] }
  result = klass.zeros(*others.first.dup.insert(axis, extent))

  # Copy each input into its slot along +axis+; zero-width inputs are skipped.
  index = [true] * rank
  items.zip(padded).inject(0) do |start, (a, dims)|
    stop = start + dims[axis]
    if stop > start
      index[axis] = start...stop
      result[*index] = a
    end
    stop
  end
  result
end
|