deeplearning4j-1.ND4J

简介

NDArray本质上是一个n维数组, 如果你学过python,我们可以认为就是java版的Numpy。

一些概念

名词 含义 说明
rank 维度数 二维数组的rank为2,三维数组的rank为3
shape 形状 几行几列的数组
length 长度 数组中元素的总数, 即 行数* 列数
type 数据类型 默认:float

存储

INDArray背后的数据是堆外存储的:也就是说,它存储在Java虚拟机(JVM)之外。

编码

可以按C(行主要)或Fortran(列主要)顺序对NDArray进行编码。

创建NDArrays

创建方式

zeros

创建一个2行3列,元素值为0的ndarray

INDArray zeros = Nd4j.zeros(2, 3);
System.out.println(zeros);

console:
[[         0,         0,         0], 
 [         0,         0,         0]]

ones

创建一个2行3列,元素值为1的ndarray

INDArray ones = Nd4j.ones(2, 3);
System.out.println(ones);

console:
[[    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000]]

rand

创建一个2行3列,元素值随机的ndarray

 INDArray rand = Nd4j.rand(2, 3);
 System.out.println(rand);
 
 console:
 [[    0.1140,    0.1990,    0.4751], 
 [    0.1058,    0.2520,    0.7749]]

randn

创建一个2行3列,元素值随机, 服从高斯分布(平均值为0, 方差为1)的的ndarray

INDArray randn = Nd4j.randn(2, 3);
System.out.println(randn);

console:
[[   -0.0883,    0.9208,    0.2775], 
 [   -2.0135,    2.9719,   -0.4583]]

create

根据数组生成

INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
System.out.println(indArray);

console:
[[    1.0000,    2.0000,    3.0000], 
 [    4.0000,    5.0000,    6.0000]]

获取值的方式

指定下标

INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);

float aFloat = indArray.getFloat(1, 2);
System.out.println(aFloat);

console:
6

获取行和列

INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
INDArray column = indArray.getColumn(1); // 索引为1的列
System.out.println(column);

INDArray row = indArray.getRow(1);  // 索引为1的行
System.out.println(row);

console:
[    2.0000,    5.0000]
[    4.0000,    5.0000,    6.0000]

修改值

通过索引修改

INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
indArray.putScalar(1,1,10);
System.out.println(indArray);

console:
[[    1.0000,    2.0000,    3.0000], 
 [    4.0000,   10.0000,    6.0000]]

修改整行 或 整列

INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
indArray.putRow(0, Nd4j.create(new float[]{10,20,30}, 1, 3));
System.out.println(indArray);

indArray.putColumn(1, Nd4j.create(new float[]{200,500}, 2, 1));
System.out.println(indArray);


console:
[[   10.0000,   20.0000,   30.0000], 
 [    4.0000,    5.0000,    6.0000]]
[[   10.0000,  200.0000,   30.0000], 
 [    4.0000,  500.0000,    6.0000]]

运算

参考 6.5 矩阵的运算及其运算规则

加减法

两个矩阵相加减,即它们相同位置的元素相加减!

所以相加减的矩阵必须行数列数相等

 INDArray ndArr1 = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
 INDArray ndArr2 = Nd4j.create(new float[]{6,5,4,3,2,1}, 2, 3);
 INDArray result = ndArr1.add(ndArr2);
 INDArray result1 = ndArr1.sub(ndArr2);
 System.out.println(ndArr1);
 System.out.println(ndArr2);
 System.out.println("--加法--");
 System.out.println(result);
 System.out.println("--减法--");
 System.out.println(result1);

console:
[[    1.0000,    2.0000,    3.0000], 
 [    4.0000,    5.0000,    6.0000]]
 
[[    6.0000,    5.0000,    4.0000], 
 [    3.0000,    2.0000,    1.0000]]
--加法--
[[    7.0000,    7.0000,    7.0000], 
 [    7.0000,    7.0000,    7.0000]]
--减法--
[[   -5.0000,   -3.0000,   -1.0000], 
 [    1.0000,    3.0000,    5.0000]]

矩阵的加减满足: 交换律结合律

交换律: A + B = B + A

结合律: (A + B) + C = A + (B + C)

乘法

当矩阵A的列数(column)等于矩阵B的行数(row)时,A与B可以相乘, 结果矩阵的形状为 (A的行数, B的列数)

INDArray ndArr1 = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
INDArray ndArr2 = Nd4j.create(new float[]{1,2,3,4,5,6}, 3, 2);
System.out.println(ndArr1);
System.out.println(ndArr2);
INDArray mmul = ndArr1.mmul(ndArr2);
System.out.println(mmul);

onsole:
[[    1.0000,    2.0000,    3.0000], 
 [    4.0000,    5.0000,    6.0000]]

[[    1.0000,    2.0000], 
 [    3.0000,    4.0000], 
 [    5.0000,    6.0000]]

[[   22.0000,   28.0000], 
 [   49.0000,   64.0000]]

除法

INDArray ndArr1 = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
INDArray ndArr2 = Nd4j.create(new float[]{2,4,6,8,10,12}, 2, 3);
INDArray result = ndArr1.div(ndArr2);
System.out.println(result);

console:
[[    0.5000,    0.5000,    0.5000], 
 [    0.5000,    0.5000,    0.5000]]

备注

  1. 本文来源DL4J中文文档/ND4J/概述