Fenrier Lab

Java 本地调用连接 Intel 数学核心库

纯 Java 语言实在算不上一种高性能的数值计算语言,当然抽象层次越高,效率一般越低,也没什么好抱怨的。如果有相当大的数值计算需求,那么我们可以将这部分代码用 C/C++ 来实现,再通过 Java 来调用,将能很好地提升 Java 程序的性能。

Intel@reg Math Kernel Library 是 Intel 公司开发的高性能数学函数库,作为例子,这里我们只讨论 BLAS 这块,其他部分根据说明文档和例子去操作就可以了。BLAS 全称是 Basic Linear Algebra Subprograms,即基本线性代数子程序,它本身只定义了三个层次的线性代数计算函数接口,其中

  1. 是向量与向量的计算,例如点积、向量求和等,
  2. 是矩阵与向量的计算,如矩阵向量乘法等,
  3. 是矩阵与矩阵的计算,比如矩阵乘法等。

可以发现函数接口的层次恰好等同于这个函数的时间复杂度,比如向量点积的时间复杂度是 \(O(n)\), 矩阵乘向量的是 \(O(n^2)\),矩阵乘法的是 \(O(n^3)\)(当然,一些优化算法的时间复杂度确实要稍微低一点,但这里只是说的一般情况)。

下面我们以 DGEMM 函数为例作说明,这个函数进行如下操作

\[C \leftarrow \alpha op(A) op(B) +\beta C\]

其中 op 是对矩阵的转置函数,可以设置转置,也可以不转置,\(\alpha, \beta\) 是标量系数。函数参数列表如下

参数名称 解释说明
layout 矩阵存储风格 C 或者 Fortran
transA 矩阵 A 是否要转置
transB 矩阵 B 是否要转置
m 矩阵 A 的行数
n 矩阵 B 的列数
k 矩阵 A 的列数
alpha 系数
A 矩阵 A 的指针
LDA 矩阵 A 的列数
B 矩阵 B 的指针
LDB 矩阵 B 的列数
beta 系数
C 矩阵 C 的指针
LDC 矩阵 C 的列数

为了在 Java 中使用该函数,首先我们定义本地方法

public class JBlas{

	public static native void dgemm(int layout, int transA, int transB, int m, int n, int k, double alpha, double[] A, int LDA, double[] B, int LDB, double beta, double[] C, int LDC);

}

然后使用 javac 编译并生成头文件

javac JBlas.java -h .\native\src
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class com_haswalk_jblas_JBlas */

#ifndef _Included_com_haswalk_jblas_JBlas
#define _Included_com_haswalk_jblas_JBlas
#ifdef __cplusplus
extern "C" {
#endif

/*
 * Class:     com_haswalk_jblas_JBlas
 * Method:    dgemm
 * Signature: (IIIIIID[DI[DID[DI)V
 */
JNIEXPORT void JNICALL Java_com_haswalk_jblas_JBlas_dgemm
  (JNIEnv *, jclass, jint, jint, jint, jint, jint, jint, jdouble, jdoubleArray, jint, jdoubleArray, jint, jdouble, jdoubleArray, jint);

#ifdef __cplusplus
}
#endif
#endif

然后使用 C 语言实现头文件定义的函数,其中使用 cblas_dgemm 函数作为调用接口


#include <jni.h>
#include <assert.h>
#include "mkl_cblas.h"

#ifndef _Included_com_haswalk_jblas_JBlas
#define _Included_com_haswalk_jblas_JBlas
#ifdef __cplusplus
extern "C" {
#endif

/*
 * Class:     com_haswalk_jblas_JBlas
 * Method:    dgemm
 * Signature: (IIIIIID[DI[DID[DI)V
 */
JNIEXPORT void JNICALL Java_com_haswalk_jblas_JBlas_dgemm
  (JNIEnv * env, jclass clazz, jint order, jint transA, jint transB, jint m, jint n, jint k, 
  jdouble alpha, jdoubleArray A, jint LDA, jdoubleArray B, jint LDB, jdouble beta, jdoubleArray C, jint LDC){
	  
	jdouble *aElems, *bElems, *cElems;
	//将 Java 数组拷贝到 C 数组中
	aElems = (*env) -> GetDoubleArrayElements(env, A, NULL);
	bElems = (*env) -> GetDoubleArrayElements(env, B, NULL);
	cElems = (*env) -> GetDoubleArrayElements(env, C, NULL);
	assert(aElems && bElems && cElems);
	//计算
	cblas_dgemm((CBLAS_ORDER) order, (CBLAS_TRANSPOSE) transA, 
	(CBLAS_TRANSPOSE) transB, (int)m, (int)n, (int)k, alpha, aElems, (int)LDA, bElems, (int)LDB, beta, cElems, (int)LDC);
	//释放资源
	(*env) -> ReleaseDoubleArrayElements(env, C, cElems, 0);
	(*env) -> ReleaseDoubleArrayElements(env, B, bElems, JNI_ABORT);
	(*env) -> ReleaseDoubleArrayElements(env, A, aElems, JNI_ABORT);
  }

#ifdef __cplusplus
}
#endif
#endif

然后编译成动态链接库

gcc -c .\native\src\com_haswalk_jblas_JBlas.c -o .\native\src\com_haswalk_jblas_JBlas.o -I "%MKL_ROOT%\include" -I "%JAVA_HOME%\include" -I "%JAVA_HOME%\include\win32"

gcc -shared -o .\native\lib\com_haswalk_jblas_JBlas.dll .\native\src\com_haswalk_jblas_JBlas.o -L "%MKL_ROOT%\lib\intel64_win" -lmkl_rt

在 Java 中调用 dgemm 这个函数之前先加载动态链接库

System.load("path/to/com_haswalk_jblas_JBlas.dll");

然后我们可以做一个小例子

@Test
    public void testdgemm() {
        int m = 4, n = 5, k = 3;
        int LDA = k, LDB = n, LDC = n;
        double[] A = {1,2,3,
                        4,2,2,
                        3,4,5,
                        4,2,1};
        double[] B = {1,2,3,6,6,
                        3,1,5,1,7,
                        4,6,8,8,10};
        double[] C = new double[m * n];

        double alpha = 1;
        double beta = 0;
		//101, 111 分别表示行存储,不转置,其定义可以在 mkl 的 include 文件夹下的 mkl_cblas.h 中找到
        JBlas.dgemm(101, 111, 111,
                m, n, k, alpha, A, LDA, B, LDB, beta, C, LDC);
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(C[j + i * n] + "\t");
            }
            System.out.println();
        }
    }
	/*
	output:
	19.0	22.0	37.0	32.0	50.0	
	18.0	22.0	38.0	42.0	58.0	
	35.0	40.0	69.0	62.0	96.0	
	14.0	16.0	30.0	34.0	48.0
	*/

最后我们做一下性能测试,来对比一下 Java 和 C 语言调用 MKL 函数的加速效果。

首先是 C 的测试代码

#include <stdio.h>
#include <stdlib.h>
#include "mkl.h"

int main()
{
    double *A, *B, *C;
    int m, n, p, i, j;
    double alpha, beta;

    m = 2000, p = 200, n = 1000;
    alpha = 1.0; beta = 0.0;

    A = (double *)mkl_malloc( m*p*sizeof( double ), 64 );
    B = (double *)mkl_malloc( p*n*sizeof( double ), 64 );
    C = (double *)mkl_malloc( m*n*sizeof( double ), 64 );
    if (A == NULL || B == NULL || C == NULL) {
        printf( "\n ERROR: Can't allocate memory for matrices. Aborting... \n\n");
        mkl_free(A);
        mkl_free(B);
        mkl_free(C);
        return 1;
    }

    printf (" Intializing matrix data \n\n");
    for (i = 0; i < (m*p); i++) {
        A[i] = (double)(i+1);
    }

    for (i = 0; i < (p*n); i++) {
        B[i] = (double)(-i-1);
    }

    for (i = 0; i < (m*n); i++) {
        C[i] = 0.0;
    }
	
    printf (" Computing matrix product using Intel(R) MKL dgemm function via CBLAS interface \n\n");
	
	printf(" Warmup...\n\n");
	for(i = 0; i < 10; i++) {
		cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, 
                m, n, p, alpha, A, p, B, n, beta, C, n);
	}
	
	double s_initial, s_elapsed, total;
	printf(" Loop...\n\n");
	
	for(int i = 0; i < 10; i++) {
		s_initial = dsecnd();
		cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, 
                m, n, p, alpha, A, p, B, n, beta, C, n);
		s_elapsed = dsecnd() - s_initial;
		total += s_elapsed;
		printf(" Iteration %d: %5f\n", i, s_elapsed * 1000);

	}
	printf("\n Average time: %5f\n", total /10.0 * 1000);
	printf ("\n Computations completed.\n\n");

    mkl_free(A);
    mkl_free(B);
    mkl_free(C);

    return 0;
}
/*
 Iteration 0: 36870.120442
 Iteration 1: 34055.884578
 Iteration 2: 41175.512597
 Iteration 3: 39219.469414
 Iteration 4: 37622.348522
 Iteration 5: 31426.735572
 Iteration 6: 31998.701277
 Iteration 7: 33684.410388
 Iteration 8: 38041.815511
 Iteration 9: 35845.068516

 Average time: 35994.006682
*/

Java 基准测试代码,这里我们使用 JMH 来做。

import com.haswalk.jblas.JBlas;
import com.haswalk.jblas.JBlasLayout;
import com.haswalk.jblas.JBlasTranspose;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

import java.util.concurrent.TimeUnit;

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Thread)
public class DgemmBenchmark {

    private double[] A, B, C;
    private int m, n, k;
    private int LDA, LDB, LDC;
    public DgemmBenchmark() {
        m = 2000;
        n = 200;
        k = 1000;
        LDA = k;
        LDB = n;
        LDC = n;

        A = new double[m * k];
        B = new double[k * n];
        for(int i = 0; i < m * LDA; i++) {
            A[i] = i+1;
        }
        for(int i = 0; i < k*LDB; i++) {
            B[i] = -(i-1);
        }

        C = new double[m * n];
    }

    @Benchmark
    public void run() {
        double alpha = 1;
        double beta = 0;
        JBlas.dgemm(JBlasLayout.ROW_MAJOR, JBlasTranspose.NO_TRANS, JBlasTranspose.NO_TRANS,
                m, n, k, alpha, A, LDA, B, LDB, beta, C, LDC);
    }

    public static void main(String[] args) throws RunnerException {
        Options opt = new OptionsBuilder()
                .include(DgemmBenchmark.class.getSimpleName())
                .forks(1)
                .build();
        new Runner(opt).run();
    }
}
/*
output(main):
Iteration   1: 59721.053 us/op
Iteration   2: 72918.741 us/op
Iteration   3: 66438.956 us/op
Iteration   4: 66108.340 us/op
Iteration   5: 65751.479 us/op
Iteration   6: 66027.411 us/op
Iteration   7: 56218.227 us/op
Iteration   8: 54677.101 us/op
Iteration   9: 53554.030 us/op
Iteration  10: 57047.576 us/op

Benchmark           Mode  Cnt      Score      Error  Units
DgemmBenchmark.run  avgt   20  58016.823 ± 5164.820  us/op
*/

从上面的测试可以看到,使用 Java 调用 MKL 函数,对比起直接用 C 还是有接近一倍的性能差距,这应该是 JNI 本身的调用开销造成的,但是对比起直接在 Java 里面写三层循环已经是相当不错了。

本文遵守 CC-BY-NC-4.0 许可协议。

Creative Commons License

欢迎转载,转载需注明出处,且禁止用于商业目的。