简介
NDArray本质上是一个n维数组, 如果你学过python,我们可以认为就是java版的Numpy。
一些概念
名词 | 含义 | 说明 |
---|---|---|
rank | 维度数 | 二维数组的rank为2,三维数组的rank为3 |
shape | 形状 | 几行几列的数组 |
length | 长度 | 数组中元素的总数, 即 行数* 列数 |
type | 数据类型 | 默认:float |
存储
INDArray背后的数据是堆外存储的:也就是说,它存储在Java虚拟机(JVM)之外。
编码
可以按C(行主要)或Fortran(列主要)顺序对NDArray进行编码。
创建NDArrays
创建方式
zeros
创建一个2行3列,元素值为0的ndarray
INDArray zeros = Nd4j.zeros(2, 3);
System.out.println(zeros);
console:
[[ 0, 0, 0],
[ 0, 0, 0]]
ones
创建一个2行3列,元素值为1的ndarray
INDArray ones = Nd4j.ones(2, 3);
System.out.println(ones);
console:
[[ 1.0000, 1.0000, 1.0000],
[ 1.0000, 1.0000, 1.0000]]
rand
创建一个2行3列,元素值随机的ndarray
INDArray rand = Nd4j.rand(2, 3);
System.out.println(rand);
console:
[[ 0.1140, 0.1990, 0.4751],
[ 0.1058, 0.2520, 0.7749]]
randn
创建一个2行3列,元素值随机, 服从高斯分布(平均值为0, 方差为1)的的ndarray
INDArray randn = Nd4j.randn(2, 3);
System.out.println(randn);
console:
[[ -0.0883, 0.9208, 0.2775],
[ -2.0135, 2.9719, -0.4583]]
create
根据数组生成
INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
System.out.println(indArray);
console:
[[ 1.0000, 2.0000, 3.0000],
[ 4.0000, 5.0000, 6.0000]]
获取值的方式
指定下标
INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
float aFloat = indArray.getFloat(1, 2);
System.out.println(aFloat);
console:
6
获取行和列
INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
INDArray column = indArray.getColumn(1); // 索引为1的列
System.out.println(column);
INDArray row = indArray.getRow(1); // 索引为1的行
System.out.println(row);
console:
[ 2.0000, 5.0000]
[ 4.0000, 5.0000, 6.0000]
修改值
通过索引修改
INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
indArray.putScalar(1,1,10);
System.out.println(indArray);
console:
[[ 1.0000, 2.0000, 3.0000],
[ 4.0000, 10.0000, 6.0000]]
修改整行 或 整列
INDArray indArray = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
indArray.putRow(0, Nd4j.create(new float[]{10,20,30}, 1, 3));
System.out.println(indArray);
indArray.putColumn(1, Nd4j.create(new float[]{200,500}, 2, 1));
System.out.println(indArray);
console:
[[ 10.0000, 20.0000, 30.0000],
[ 4.0000, 5.0000, 6.0000]]
[[ 10.0000, 200.0000, 30.0000],
[ 4.0000, 500.0000, 6.0000]]
运算
加减法
两个矩阵相加减,即它们相同位置的元素相加减!
所以相加减的矩阵必须行数和列数相等
INDArray ndArr1 = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
INDArray ndArr2 = Nd4j.create(new float[]{6,5,4,3,2,1}, 2, 3);
INDArray result = ndArr1.add(ndArr2);
INDArray result1 = ndArr1.sub(ndArr2);
System.out.println(ndArr1);
System.out.println(ndArr2);
System.out.println("--加法--");
System.out.println(result);
System.out.println("--减法--");
System.out.println(result1);
console:
[[ 1.0000, 2.0000, 3.0000],
[ 4.0000, 5.0000, 6.0000]]
[[ 6.0000, 5.0000, 4.0000],
[ 3.0000, 2.0000, 1.0000]]
--加法--
[[ 7.0000, 7.0000, 7.0000],
[ 7.0000, 7.0000, 7.0000]]
--减法--
[[ -5.0000, -3.0000, -1.0000],
[ 1.0000, 3.0000, 5.0000]]
矩阵的加减满足: 交换律和结合律
交换律: A + B = B + A
结合律: (A + B) + C = A + (B + C)
乘法
当矩阵A的列数(column)等于矩阵B的行数(row)时,A与B可以相乘, 结果矩阵的形状为 (A的行数, B的列数)
INDArray ndArr1 = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
INDArray ndArr2 = Nd4j.create(new float[]{1,2,3,4,5,6}, 3, 2);
System.out.println(ndArr1);
System.out.println(ndArr2);
INDArray mmul = ndArr1.mmul(ndArr2);
System.out.println(mmul);
onsole:
[[ 1.0000, 2.0000, 3.0000],
[ 4.0000, 5.0000, 6.0000]]
[[ 1.0000, 2.0000],
[ 3.0000, 4.0000],
[ 5.0000, 6.0000]]
[[ 22.0000, 28.0000],
[ 49.0000, 64.0000]]
除法
INDArray ndArr1 = Nd4j.create(new float[]{1,2,3,4,5,6}, 2, 3);
INDArray ndArr2 = Nd4j.create(new float[]{2,4,6,8,10,12}, 2, 3);
INDArray result = ndArr1.div(ndArr2);
System.out.println(result);
console:
[[ 0.5000, 0.5000, 0.5000],
[ 0.5000, 0.5000, 0.5000]]
备注
- 本文来源DL4J中文文档/ND4J/概述