MENU

JVM深度学习实践概述

November 6, 2020 • 瞎折腾

10 月用一整月写完了一篇论文,因此也没能兑现月更的承诺,实属丢人现眼。本文将总结 10 月以及最近这几天的心得与体会,并希望能够给更多遭遇同样境况的人一些方向。

展开目录

尽管现在在深度学习方面最火热的编程语言是 Python,但作为高贵的 Java 的程序员,使用全动态的 Python 语言有害于我的纯粹性。因此在前些日子赶论文的时候,一大部分时间都花在了使用 JVM 进行深度学习上。从一开始的无数次 Python 入门失败,到今天能够写成文章总结 JVM 上进行深度学习的经验,我认为这比写论文更有意思,毕竟 Java 天下第一,在应用领域我还是最支持 JVM。

警告

本文对 Python 持有强烈的偏见,如有不适请及时关闭网页。

引言展开目录

尽管序和引言在中文颇有重复的嫌疑,但即便如此我还是要将他们分开。在本节我将试图阐述为什么 Python 做机器学习很火,而 Java 上的深度学习一直不温不火。这里需要强调的是本文所说的深度学习,是通常需要显卡进行训练的任务,而不是传统的机器学习。在传统机器学习领域,我认为 Java 的 Weka 库和 Oracle 的 tribuo 都已经做的非常优秀了。

关于 Python 在机器学习领域的火爆,网络上有那么几种说法,我听到的一种说法是 CPython 对于 C 系的库调用起来很方便,因为机器学习库底层都是 C/C++ 写的,比较快,而 Java 通过 JavaCPP 或者 JNI 调用普遍比较慢。但是我觉得这种说法并不稳妥,我认为更实际的原因可能是 Python 好学,毕竟深度学习的核心在于数学,而搞数学的人在以前能写个 Matlab 就算不错了,而 Python 又没有什么严格的工程结构,王八蛋都能学会,故因为便于入门,许多搞理论研究的人使用 Python 构建他们通往实践的路。另一方面 Java 则没有这么好入门,各种工程设计的模式以及强类型的束缚,以及 JVM 的深不可测,使得一些 Java 程序员写了十几年代码也不敢说自己真正了解 Java,因此 Java 才在深度学习领域不温不火吧。

另外值得一说的是 Python 本身只是一个编程语言规范,而本文所指的 Python 实际上是其基于 C/C++ 的实现,即 CPython,在以前也曾经有过 JVM 实现的 Python,但是其版本止步于 2.7,而后来 Oracle 的 GraaleVM 也试图打通 Python 与 JVM 的隔阂,但就现阶段来说 GraalVM 上通过 Python 调用 C 的 API 还仍然在开发中。

Support for more extension modules is high priority for us. We are actively building out our support for the Python C API to make extensions such as NumPy, SciPy, Scikit-learn, Pandas, Tensorflow and the like work fully. This work means that some other extensions might also already work, but we're not actively testing other extensions right now and cannot promise anything. Note that to try other extensions on this implementation, you have to download, build, and install them manually for now.

支持更过拓展模块是我们最先要做的。我们正在努力开发 Python 的 C 语言 API,使得诸如 NumPy、SciPy、Scikit-learn、Pandas、Tensorflow 之类的框架完全正常工作。这些工作意味着其他一些扩展也可能已经起作用,但是我们现在不积极测试这些扩展,也无法做出任何保证或承诺。 请注意,要尝试此实现的其他扩展,您现在必须手动下载,构建和安装它们。

来源:Github - GraalVM

翻译过来就是 GraalVM 的 Python 支持在做,也许能用,但是不稳定。尽管现阶段没有任何应用的可能(我不希望在代码中引入任何不稳定性),但是我还是很看好 GraalVM,除了 Python,它还有一个 Native Image,能够将 Java 等其支持的语言编译为静态镜像,十分值得期待。

但话说回来,DDL 不等人,我们没法等着 GraalVM 开发完毕再去使用 Java 调用 Python 的代码搞深度学习。因此本文将就现阶段可用的,且好用的、符合 Java 程序员习惯的一些深度学习框架进行介绍与对比。

最后还需要提一点,Python 上面的深度学习库依赖于其底层的 C++ 库,原则上可以通过 JNI 调用他们去做 JVM 的深度学习,但是很显然这有不是很优雅,我也不认同他是 JVM 上的深度学习。但后续介绍的一些库也使用,甚至是直接使用的 Python 深度学习库的 C++ 库,但那样我认为可以是 JVM 上的深度学习。至于这二者之间的区别,还请读者自行拿捏。

不适用的选项展开目录

本节介绍一些对我现有环境不可用的选项,这些选项或许可以在其他环境中正常工作。我的环境是 Windows 10.

MXNet Scala Edition 展开目录

Apache 的 MXNet 在主要面向 Python 语言之外,还额外的支持了 7 种语言,其中 JVM 语言占了 3 种,分别是 Java、Scala 和 Clojure。其中 Java 语言的 API 只能用于现有模型的部署,不能用于训练,Scala 语言的 API 可以进行训练,但只支持 CPU,而 Clojure 版本的 API 则是对 Scala API 的再封装。

我选择使用 Scala 版本,尽管我对 SBT 一无所知,但还是在学习之初(在 Hello World 之前)给官方提了一个 Issue。后来发现 Scala 版本的 API 不支持 Windows,只好作罢。

Tensorflow Scala 展开目录

接着 Scala 的热度,自然想到的就是 Tensorflow 在 Scala 上的非官方封装。对于封装的 Scala API 本身并没有什么平台之分,而关键在于 Tensorflow 的动态库,对于预编译版本,并没有 Windows 的选项,所以只能从 Windows 上本地编译 Tensorflow。而 Windows 上编译,反正我是不看好。倘若真的需要在 Windows 上去编译库,我宁愿再装一个 Linux 双系统也不要安装微软的 Visual Studio。所以这个选项也只好作罢。

Tensorflow Java 0.2 展开目录

如果 Scala 的路走不通,那就剩下官方的 Java API 了。Tensorflow 的 Java API 最近发布了 0.2 版,该版本支持使用 Java 进行模型训练。但是目前来说文档很少,只有官方的一页,而且 Tensorflow 在 2.0 版本全面拥抱 Keras,使得 Java 版本的 Tensorflow 几乎只能从最底层的做起。尽管有 tensorflow-framework 起到类似 Keras 的作用,但目前来说还十分不完善。或许还得等它发展一些时间吧。于是 Tensorflow Java 也只能作罢。

适用的选项展开目录

DeepLearning4J 展开目录

DL4J 是目前我实践最多的一个库,这个库完全不依赖 Python 库(但底层还是 C++ 实现)的、存粹的商业深度学习库。是的,它是一个商业库,背后的公司曾经是 SkyMind,后来变更为 konduit.ai。它具有所有其他可以商业使用的库的全部特点:诸如完善的文档,规范的代码组织结构,积极的社区支持等,总的来说用于开发项目十分顺手,十分符合一个 Java 程序员的使用习惯。

但另一方面,该库对于自定义拓展并不是很友好。目前该框架中有一个 SameDiff,旨在模型中引入你自定义的行为,相对于 PyTorch 之类的针对普通函数的自动微分,Java 实现中最困难的还是类型系统。但是好在这种困难并非只针对于这个框架,回想一下 Spring 框架中为了迎合 MVC 架构,一个 Hello World 的网页都要写 5 个或者更多的类才能实现,在深度学习上只不过是把论文里的公式实现柔和进了继承与覆盖的过程中罢了。即便如此,我还是很开心能够用这个库完成我的论文,比使用 Python 快乐多了。

为了自定义一个 Attention 层,首先你需要声明 WorkSpace,告诉框架你这个层都需要哪些变量作为参数,这样框架能够根据环境自动在内存或显存种帮你管理这些变量。之后定义运算过程,这个定义只能使用框架提供的有限数学操作,但是你可以使用循环或分支等语句动态的根据参数内容变化最终组合而成的表达式,之后反向传播的时候,框架会根据已知的操作和代码给出的组合自动计算导数等参数。所以其中最困难的地方就是如何面对抽象的对象,将其想象为实际的数据,并利用有限的操作来实现你所要的功能。在这方面我还不是太行,所以至今尚未能实现 CRF 层(大约半年前),转而使用了 Attention。但是随着一篇论文的实践,之后我或许可以做到了。

对于这个框架,如果需要出于研究目的使用的话,其自带的一些东西很可能几乎不能够覆盖到你的需求,你需要自己实现。而实现的过程又极度依赖于 Java 的开发模式,虽然使用 Kotlin 在语法上简化了相当一部分工作,但如果对于 Java 本身并不熟练,要拓展出自己的函数还是十分困难的。而且就目前来说,如何在这个框架上实现一个 Transformer Block 我还毫无头绪。因此我大概是不会选择使用这个框架进行研究性工作的。但如果需要开发商业或工业的项目,这些项目的重点不在于模型有多新,而是要求程序有多稳,那么这个框架绝对会是我的首选。

Deep Java Library 展开目录

DJL 是亚马逊开发的深度学习库。由于其稀少的文档,使得我在写论文之初选择了 DL4J,直到论文完结后我才开始重新审视这个库。

就这个库本身的设计来说,其底层依赖于四个 Python 深度学习库之一,可以任选 Apache MXNet、PyTorch、TensorFlow 和 ONNX Runtime 之一作为底层,而 DJL 则提供了一个统一的抽象,我选择 PyTorch。在使用时只需要配置好对应的 maven 依赖,完全不需要自己构建什么,maven 依赖中会包含对应底层所需要的 C++ 实现库,因此和 DL4J 一样无痛使用。当然了,前提是你能找到官方的这些 maven 包的名称。

是的,文档是目前为止这个库最痛苦的地方。你以为在 Setup 一节中会告诉你如何在 Maven 或者 Gradle 中使用这个库?想都别想,只有一段这样的文字:

(Optional) Import using Gradle/Maven wrappers

You use Gradle and Maven wrappers to build the project, so you don't need to install Gradle or Maven. However, you should have basic knowledge about the Gradle or Maven build system.

嗯?那些依赖的包名呢???好在这网页有搜索功能,直接搜索 maven,找到了相关页面

总之使用这个库还是挺重度依赖搜索功能的,无论是搜索文档还是搜索谷歌,总归还是有些困难。但是对于库本身来说,因为其底层直接使用各种 Python 库的依赖,因此在使用上更倾向于已有深度学习库的模样,当然在抽象的过程中会有一些改变,但整体而言不像 DL4J 那样另起炉灶。在扩展自定义操作方面,可以直接使用 LambdaBlock,虽然一样要和 DL4J 操作一个抽象的对象,但至少在工作量上轻松很多。另外值得一提的是,这个库将模型表示为 Block 的组合,用起来会和 Keras 的 Sequential 十分相似。

昨天摸索了一下午,终于拼凑出来了该库的 MNIST 程序。

依赖如下(使用 Gradle 6.7):

  • implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.4.10"
  • implementation group: 'ch.qos.logback', name: 'logback-classic', version: '1.2.3'
  • implementation platform("ai.djl:bom:0.8.0")
  • implementation 'ai.djl:api:0.8.0'
  • implementation "ai.djl:basicdataset:0.8.0"
  • implementation "ai.djl:model-zoo:0.8.0"
  • runtimeOnly 'ai.djl.pytorch:pytorch-engine:0.8.0'
  • // runtimeOnly 'ai.djl.pytorch:pytorch-native-auto:1.6.0'
  • runtimeOnly 'ai.djl.pytorch:pytorch-native-cu102:1.6.0:win-x86_64'

训练程序如下(使用 Kotlin):

  • package info.skyblond.djl.demo
  • import ai.djl.Model
  • import ai.djl.basicdataset.Mnist
  • import ai.djl.ndarray.types.Shape
  • import ai.djl.nn.Activation
  • import ai.djl.nn.Blocks
  • import ai.djl.nn.SequentialBlock
  • import ai.djl.nn.core.Linear
  • import ai.djl.training.DefaultTrainingConfig
  • import ai.djl.training.EasyTrain
  • import ai.djl.training.evaluator.Accuracy
  • import ai.djl.training.listener.TrainingListener
  • import ai.djl.training.loss.Loss
  • import java.nio.file.Files
  • import java.nio.file.Paths
  • fun main() {
  • Model.newInstance("MNIST").use { model ->
  • val batchSize = 32
  • val mnist = Mnist.builder()
  • .setSampling(batchSize, true)
  • .build()
  • val input = 28 * 28
  • val output = 10
  • val hiddenLayerSizes = intArrayOf(128, 64)
  • val sequentialBlock = SequentialBlock()
  • sequentialBlock.add(Blocks.batchFlattenBlock(input.toLong()))
  • for (hiddenSize in hiddenLayerSizes) {
  • sequentialBlock.add(Linear.builder().setUnits(hiddenSize.toLong()).build())
  • sequentialBlock.add(Activation::relu)
  • }
  • sequentialBlock.add(Linear.builder().setUnits(output.toLong()).build())
  • model.block = sequentialBlock
  • val config = DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
  • .addEvaluator(Accuracy())
  • .addTrainingListeners(*TrainingListener.Defaults.logging())
  • val trainer = model.newTrainer(config)
  • trainer.initialize(Shape(1, 28 * 28))
  • val epoch = 2
  • for (i in 0 until epoch) {
  • for (batch in trainer.iterateDataset(mnist)) {
  • // During trainBatch, we update the loss and evaluators with the results for the training batch.
  • EasyTrain.trainBatch(trainer, batch)
  • // Now, we update the model parameters based on the results of the latest trainBatch
  • trainer.step()
  • // We must make sure to close the batch to ensure all the memory associated with the batch is cleared quickly.
  • // If the memory isn't closed after each batch, you will very quickly run out of memory on your GPU
  • batch.close()
  • }
  • // Call the end epoch event for the training listeners now that we are done
  • trainer.notifyListeners { listener: TrainingListener ->
  • listener.onEpoch(trainer)
  • }
  • }
  • val modelDir = Paths.get("build/mlp")
  • Files.createDirectories(modelDir)
  • model.setProperty("Epoch", epoch.toString())
  • model.save(modelDir, "mlp")
  • }
  • }

推理(即利用模型做预测)程序(使用 Kotlin):

  • package info.skyblond.djl.demo
  • import ai.djl.Model
  • import ai.djl.basicmodelzoo.basic.Mlp
  • import ai.djl.modality.Classifications
  • import ai.djl.modality.cv.Image
  • import ai.djl.modality.cv.ImageFactory
  • import ai.djl.modality.cv.util.NDImageUtils
  • import ai.djl.ndarray.NDArray
  • import ai.djl.ndarray.NDList
  • import ai.djl.translate.Batchifier
  • import ai.djl.translate.Translator
  • import ai.djl.translate.TranslatorContext
  • import java.nio.file.Paths
  • import java.util.stream.Collectors
  • import java.util.stream.IntStream
  • fun main() {
  • val img = ImageFactory.getInstance().fromUrl("https://djl-ai.s3.amazonaws.com/resources/images/0.png");
  • val modelDir = Paths.get("build/mlp")
  • val model = Model.newInstance("mlp")
  • model.block = Mlp(28 * 28, 10, intArrayOf(128, 64))
  • model.load(modelDir)
  • val translator: Translator<Image, Classifications> = object : Translator<Image, Classifications> {
  • override fun processInput(ctx: TranslatorContext, input: Image): NDList {
  • // Convert Image to NDArray
  • val array: NDArray = input.toNDArray(ctx.ndManager, Image.Flag.GRAYSCALE)
  • return NDList(NDImageUtils.toTensor(array))
  • }
  • override fun processOutput(ctx: TranslatorContext?, list: NDList): Classifications? {
  • // Create a Classifications with the output probabilities
  • val probabilities = list.singletonOrThrow().softmax(0)
  • val classNames = IntStream.range(0, 10).mapToObj { i: Int -> i.toString() }.collect(Collectors.toList())
  • return Classifications(classNames, probabilities)
  • }
  • // The Batchifier describes how to combine a batch together
  • // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array
  • override fun getBatchifier(): Batchifier {
  • return Batchifier.STACK
  • }
  • }
  • val predictor = model.newPredictor(translator)
  • val classifications = predictor.predict(img)
  • println(classifications.toString())
  • }

之后我会更倾向于使用这个库去做一些研究性的东西,从结构上看它会比 DL4J 更灵活一些。尽管在使用的过程中会有一些阻碍,但好在并不耽误使用,而且代码本身拥有详细的注释,因此一旦引用了 Maven 依赖,便可以在 IDE 中通过阅读注释等方法解决半数以上的问题,余下的一半几乎都可以在搜索引擎中得到解决。

小结展开目录

DL4J 和 DJL 都能够很方便的通过 Maven 或 Gradle 构建深度学习环境,而不需要依赖 Python 运行环境。至此我认为这两个库已经能够很好的满足在 JVM 上进行深度学习的需求了,但是除了深度学习本身,关于数据可视化以及深度学习附属的一些领域还差一些。文中没有提到的是 DL4J 可以通过官方支持使用网页来查看训练过程中的一些数据,诸如当前迭代的分数、更新参数的比例等,有助于调整超参。并且 DL4J 对于训练和评估的过程十分完备,除了最普通的准确率(Accuracy)之外,还有精确度(Precision)、召回率和 F1 的计算,而 DJL 尚未发现有网页 UI,而训练过程中的评估,官方也只给了 Accuracy 一种。但好在 DJL 也有规范的接口,可以通过继承 Evaluator 来实现自己的评估器。

尽管还需要一些时间才能更加成熟,无论如何,能够在 JVM 进行深度学习都是一件很值得高兴的事情。


知识共享许可协议
JVM 深度学习实践概述天空 Blond 采用 知识共享 署名 - 非商业性使用 - 相同方式共享 4.0 国际 许可协议进行许可。
本许可协议授权之外的使用权限可以从 https://skyblond.info/about.html 处获得。

Archives QR Code
QR Code for this page
Tipping QR Code
Leave a Comment

4 Comments
  1. ALways ALways

    23 年的我,过来补一下课

  2. 星座占卜 星座占卜

    文章写的很好啊,赞 (ㆆᴗㆆ),每日打卡~~

  3. 计算机专业的同学的世界真是好丰富多彩啊... 我现在每天反复洗瓶子过柱子实在是头疼...

    1. @NiceBowl 其实也没那么丰富多彩,一开始说写论文挺有意思,可是做到后面各种理论与实际之间巨大的隔阂,甚至好几次都劝退,到最后理论到实践有了,还得考虑工程结构,各种性能与易用性优化。如果不把这东西当成自己的兴趣爱好的话,我觉得根本坚持不下来。(顺带一提,本来一开始报名的有 5 个人,到最后结题只剩下一个了 2333