Java 本地调用连接 Intel 数学核心库
纯 Java 语言实在算不上一种高性能的数值计算语言,当然抽象层次越高,效率一般越低,也没什么好抱怨的。如果有相当大的数值计算需求,那么我们可以将这部分代码用 C/C++ 来实现,再通过 Java 来调用,将能很好地提升 Java 程序的性能。
Intel@reg Math Kernel Library 是 Intel 公司开发的高性能数学函数库,作为例子,这里我们只讨论 BLAS 这块,其他部分根据说明文档和例子去操作就可以了。BLAS 全称是 Basic Linear Algebra Subprograms,即基本线性代数子程序,它本身只定义了三个层次的线性代数计算函数接口,其中
- 是向量与向量的计算,例如点积、向量求和等,
- 是矩阵与向量的计算,如矩阵向量乘法等,
- 是矩阵与矩阵的计算,比如矩阵乘法等。
可以发现函数接口的层次恰好等同于这个函数的时间复杂度,比如向量点积的时间复杂度是 \(O(n)\), 矩阵乘向量的是 \(O(n^2)\),矩阵乘法的是 \(O(n^3)\)(当然,一些优化算法的时间复杂度确实要稍微低一点,但这里只是说的一般情况)。
下面我们以 DGEMM 函数为例作说明,这个函数进行如下操作
\[C \leftarrow \alpha op(A) op(B) +\beta C\]其中 op 是对矩阵的转置函数,可以设置转置,也可以不转置,\(\alpha, \beta\) 是标量系数。函数参数列表如下
参数名称 | 解释说明 |
layout | 矩阵存储风格 C 或者 Fortran |
transA | 矩阵 A 是否要转置 |
transB | 矩阵 B 是否要转置 |
m | 矩阵 A 的行数 |
n | 矩阵 B 的列数 |
k | 矩阵 A 的列数 |
alpha | 系数 |
A | 矩阵 A 的指针 |
LDA | 矩阵 A 的列数 |
B | 矩阵 B 的指针 |
LDB | 矩阵 B 的列数 |
beta | 系数 |
C | 矩阵 C 的指针 |
LDC | 矩阵 C 的列数 |
为了在 Java 中使用该函数,首先我们定义本地方法
public class JBlas{
public static native void dgemm(int layout, int transA, int transB, int m, int n, int k, double alpha, double[] A, int LDA, double[] B, int LDB, double beta, double[] C, int LDC);
}
然后使用 javac 编译并生成头文件
javac JBlas.java -h .\native\src
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class com_haswalk_jblas_JBlas */
#ifndef _Included_com_haswalk_jblas_JBlas
#define _Included_com_haswalk_jblas_JBlas
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: com_haswalk_jblas_JBlas
* Method: dgemm
* Signature: (IIIIIID[DI[DID[DI)V
*/
JNIEXPORT void JNICALL Java_com_haswalk_jblas_JBlas_dgemm
(JNIEnv *, jclass, jint, jint, jint, jint, jint, jint, jdouble, jdoubleArray, jint, jdoubleArray, jint, jdouble, jdoubleArray, jint);
#ifdef __cplusplus
}
#endif
#endif
然后使用 C 语言实现头文件定义的函数,其中使用 cblas_dgemm 函数作为调用接口
#include <jni.h>
#include <assert.h>
#include "mkl_cblas.h"
#ifndef _Included_com_haswalk_jblas_JBlas
#define _Included_com_haswalk_jblas_JBlas
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: com_haswalk_jblas_JBlas
* Method: dgemm
* Signature: (IIIIIID[DI[DID[DI)V
*/
JNIEXPORT void JNICALL Java_com_haswalk_jblas_JBlas_dgemm
(JNIEnv * env, jclass clazz, jint order, jint transA, jint transB, jint m, jint n, jint k,
jdouble alpha, jdoubleArray A, jint LDA, jdoubleArray B, jint LDB, jdouble beta, jdoubleArray C, jint LDC){
jdouble *aElems, *bElems, *cElems;
//将 Java 数组拷贝到 C 数组中
aElems = (*env) -> GetDoubleArrayElements(env, A, NULL);
bElems = (*env) -> GetDoubleArrayElements(env, B, NULL);
cElems = (*env) -> GetDoubleArrayElements(env, C, NULL);
assert(aElems && bElems && cElems);
//计算
cblas_dgemm((CBLAS_ORDER) order, (CBLAS_TRANSPOSE) transA,
(CBLAS_TRANSPOSE) transB, (int)m, (int)n, (int)k, alpha, aElems, (int)LDA, bElems, (int)LDB, beta, cElems, (int)LDC);
//释放资源
(*env) -> ReleaseDoubleArrayElements(env, C, cElems, 0);
(*env) -> ReleaseDoubleArrayElements(env, B, bElems, JNI_ABORT);
(*env) -> ReleaseDoubleArrayElements(env, A, aElems, JNI_ABORT);
}
#ifdef __cplusplus
}
#endif
#endif
然后编译成动态链接库
gcc -c .\native\src\com_haswalk_jblas_JBlas.c -o .\native\src\com_haswalk_jblas_JBlas.o -I "%MKL_ROOT%\include" -I "%JAVA_HOME%\include" -I "%JAVA_HOME%\include\win32"
gcc -shared -o .\native\lib\com_haswalk_jblas_JBlas.dll .\native\src\com_haswalk_jblas_JBlas.o -L "%MKL_ROOT%\lib\intel64_win" -lmkl_rt
在 Java 中调用 dgemm 这个函数之前先加载动态链接库
System.load("path/to/com_haswalk_jblas_JBlas.dll");
然后我们可以做一个小例子
@Test
public void testdgemm() {
int m = 4, n = 5, k = 3;
int LDA = k, LDB = n, LDC = n;
double[] A = {1,2,3,
4,2,2,
3,4,5,
4,2,1};
double[] B = {1,2,3,6,6,
3,1,5,1,7,
4,6,8,8,10};
double[] C = new double[m * n];
double alpha = 1;
double beta = 0;
//101, 111 分别表示行存储,不转置,其定义可以在 mkl 的 include 文件夹下的 mkl_cblas.h 中找到
JBlas.dgemm(101, 111, 111,
m, n, k, alpha, A, LDA, B, LDB, beta, C, LDC);
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
System.out.print(C[j + i * n] + "\t");
}
System.out.println();
}
}
/*
output:
19.0 22.0 37.0 32.0 50.0
18.0 22.0 38.0 42.0 58.0
35.0 40.0 69.0 62.0 96.0
14.0 16.0 30.0 34.0 48.0
*/
最后我们做一下性能测试,来对比一下 Java 和 C 语言调用 MKL 函数的加速效果。
首先是 C 的测试代码
#include <stdio.h>
#include <stdlib.h>
#include "mkl.h"
int main()
{
double *A, *B, *C;
int m, n, p, i, j;
double alpha, beta;
m = 2000, p = 200, n = 1000;
alpha = 1.0; beta = 0.0;
A = (double *)mkl_malloc( m*p*sizeof( double ), 64 );
B = (double *)mkl_malloc( p*n*sizeof( double ), 64 );
C = (double *)mkl_malloc( m*n*sizeof( double ), 64 );
if (A == NULL || B == NULL || C == NULL) {
printf( "\n ERROR: Can't allocate memory for matrices. Aborting... \n\n");
mkl_free(A);
mkl_free(B);
mkl_free(C);
return 1;
}
printf (" Intializing matrix data \n\n");
for (i = 0; i < (m*p); i++) {
A[i] = (double)(i+1);
}
for (i = 0; i < (p*n); i++) {
B[i] = (double)(-i-1);
}
for (i = 0; i < (m*n); i++) {
C[i] = 0.0;
}
printf (" Computing matrix product using Intel(R) MKL dgemm function via CBLAS interface \n\n");
printf(" Warmup...\n\n");
for(i = 0; i < 10; i++) {
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
m, n, p, alpha, A, p, B, n, beta, C, n);
}
double s_initial, s_elapsed, total;
printf(" Loop...\n\n");
for(int i = 0; i < 10; i++) {
s_initial = dsecnd();
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
m, n, p, alpha, A, p, B, n, beta, C, n);
s_elapsed = dsecnd() - s_initial;
total += s_elapsed;
printf(" Iteration %d: %5f\n", i, s_elapsed * 1000);
}
printf("\n Average time: %5f\n", total /10.0 * 1000);
printf ("\n Computations completed.\n\n");
mkl_free(A);
mkl_free(B);
mkl_free(C);
return 0;
}
/*
Iteration 0: 36870.120442
Iteration 1: 34055.884578
Iteration 2: 41175.512597
Iteration 3: 39219.469414
Iteration 4: 37622.348522
Iteration 5: 31426.735572
Iteration 6: 31998.701277
Iteration 7: 33684.410388
Iteration 8: 38041.815511
Iteration 9: 35845.068516
Average time: 35994.006682
*/
Java 基准测试代码,这里我们使用 JMH 来做。
import com.haswalk.jblas.JBlas;
import com.haswalk.jblas.JBlasLayout;
import com.haswalk.jblas.JBlasTranspose;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import java.util.concurrent.TimeUnit;
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Thread)
public class DgemmBenchmark {
private double[] A, B, C;
private int m, n, k;
private int LDA, LDB, LDC;
public DgemmBenchmark() {
m = 2000;
n = 200;
k = 1000;
LDA = k;
LDB = n;
LDC = n;
A = new double[m * k];
B = new double[k * n];
for(int i = 0; i < m * LDA; i++) {
A[i] = i+1;
}
for(int i = 0; i < k*LDB; i++) {
B[i] = -(i-1);
}
C = new double[m * n];
}
@Benchmark
public void run() {
double alpha = 1;
double beta = 0;
JBlas.dgemm(JBlasLayout.ROW_MAJOR, JBlasTranspose.NO_TRANS, JBlasTranspose.NO_TRANS,
m, n, k, alpha, A, LDA, B, LDB, beta, C, LDC);
}
public static void main(String[] args) throws RunnerException {
Options opt = new OptionsBuilder()
.include(DgemmBenchmark.class.getSimpleName())
.forks(1)
.build();
new Runner(opt).run();
}
}
/*
output(main):
Iteration 1: 59721.053 us/op
Iteration 2: 72918.741 us/op
Iteration 3: 66438.956 us/op
Iteration 4: 66108.340 us/op
Iteration 5: 65751.479 us/op
Iteration 6: 66027.411 us/op
Iteration 7: 56218.227 us/op
Iteration 8: 54677.101 us/op
Iteration 9: 53554.030 us/op
Iteration 10: 57047.576 us/op
Benchmark Mode Cnt Score Error Units
DgemmBenchmark.run avgt 20 58016.823 ± 5164.820 us/op
*/
从上面的测试可以看到,使用 Java 调用 MKL 函数,对比起直接用 C 还是有接近一倍的性能差距,这应该是 JNI 本身的调用开销造成的,但是对比起直接在 Java 里面写三层循环已经是相当不错了。