Skip to content

Commit 92df1ad

Browse files
committed
add DJL text similarity example
1 parent 9954bd2 commit 92df1ad

File tree

4 files changed

+153
-0
lines changed

4 files changed

+153
-0
lines changed

docs/images/textsimularityheatmap.png

50.4 KB
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
// Build script for the text-similarity example: compiles the Groovy sources and
// lets the script be launched through the 'run' task of the application plugin.
apply plugin: 'groovy'
apply plugin: 'application'

repositories {
    mavenCentral()
}

// Name of the script class to launch; reused by the application settings
// and by the description of the 'run' task below.
ext {
    appName = 'UniversalSentenceEncoder'
}

application {
    mainClass = appName
}

tasks.named('run') {
    description = "Run $appName as a JVM application/Groovy script"
}

// NOTE(review): djlVersion, groovy4Version, smileVersion and slf4jVersion are not
// defined in this script; presumably they come from gradle.properties or the
// root build - confirm.
dependencies {
    // Needed to compile the script.
    implementation "ai.djl:api:$djlVersion"
    implementation "org.apache.groovy:groovy:$groovy4Version"
    implementation "com.github.haifengl:smile-plot:$smileVersion"
    implementation "com.github.haifengl:smile-math:$smileVersion"

    // Needed only when the script runs.
    runtimeOnly "ai.djl.tensorflow:tensorflow-engine:$djlVersion"
    runtimeOnly "ai.djl.tensorflow:tensorflow-model-zoo:$djlVersion"
    // This artifact is pinned to a literal version instead of $djlVersion.
    runtimeOnly "ai.djl.tensorflow:tensorflow-native-auto:2.4.1"
    runtimeOnly "org.slf4j:slf4j-jdk14:$slf4jVersion"
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<!--
2+
SPDX-License-Identifier: Apache-2.0
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
-->
16+
17+
# Language processing with DJL and TensorFlow
18+
19+
Neural networks with numerous layers of nodes allow for more complex, rich and _deeper_ processing and understanding.
20+
This example computes sentence embeddings for a handful of sentences and compares them to gauge their semantic similarity.
21+
It uses a pre-trained model and the
22+
[Deep Java Library](https://djl.ai/) backed by the
23+
[TensorFlow](https://www.tensorflow.org/) engine.
24+
25+
![Semantic textual similarity heatmap](../../docs/images/textsimularityheatmap.png)
26+
27+
Groovy code examples can be found in the [src/main/groovy](src/main/groovy) subdirectory.
28+
If you have opened the repo in IntelliJ (or your favourite IDE) you should be able to execute the examples directly in the IDE.
29+
30+
__Requirements__: The code has been tested on JDK8, JDK11 and JDK17.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import ai.djl.Application
2+
import ai.djl.ndarray.NDArrays
3+
import ai.djl.ndarray.NDList
4+
import ai.djl.repository.zoo.Criteria
5+
import ai.djl.training.util.ProgressBar
6+
import ai.djl.translate.NoBatchifyTranslator
7+
import ai.djl.translate.TranslatorContext
8+
import smile.plot.swing.Heatmap
9+
import smile.plot.swing.Palette
10+
11+
import static smile.math.MathEx.dot
12+
13+
/*
14+
* An example of inference using a universal sentence encoder model from TensorFlow Hub.
15+
* For more info see: https://tfhub.dev/google/universal-sentence-encoder/4
16+
* Inspired by: https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/UniversalSentenceEncoder.java
17+
*/
18+
19+
/**
 * Translator between an array of sentences and one embedding row per sentence.
 * Inputs are stacked into a single array here, and the single output array is
 * split back into rows along its first dimension.
 */
class MyTranslator implements NoBatchifyTranslator<String[], double[][]> {
    /**
     * Creates one NDArray per input string and stacks them into a single array.
     *
     * @param ctx translator context supplying the NDManager used to create arrays
     * @param raw the sentences to encode
     * @return a one-element NDList holding the stacked inputs
     */
    @Override
    NDList processInput(TranslatorContext ctx, String[] raw) {
        var manager = ctx.NDManager
        var sentences = new NDList(raw.collect { manager.create(it) })
        new NDList(NDArrays.stack(sentences))
    }

    /**
     * Splits the single output array into rows and converts each row to numbers.
     *
     * @param ctx translator context (unused)
     * @param list model output; must contain exactly one array, otherwise
     *             singletonOrThrow() fails
     * @return one row per index of the output's first dimension
     */
    @Override
    double[][] processOutput(TranslatorContext ctx, NDList list) {
        var output = list.singletonOrThrow()
        long rows = output.shape.get(0)
        // Each slice is read as a float[]; the Groovy cast widens the rows to double[].
        (0..<rows).collect { output.get(it).toFloatArray() } as double[][]
    }
}
37+
38+
/**
 * Loads the universal sentence encoder model with the TensorFlow engine and
 * runs it over the given sentences using MyTranslator.
 * The model and predictor are closed before the method returns.
 *
 * @param inputs the sentences to encode
 * @return the predictor's result for the inputs (double[][] per the criteria types)
 */
def predict(String[] inputs) {
    var criteria = Criteria.builder()
        .optApplication(Application.NLP.TEXT_EMBEDDING)
        .setTypes(String[], double[][])
        .optModelUrls("https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/4.tar.gz")
        .optTranslator(new MyTranslator())
        .optEngine("TensorFlow")
        .optProgress(new ProgressBar())
        .build()

    try (var model = criteria.loadModel(); var predictor = model.newPredictor()) {
        return predictor.predict(inputs)
    }
}
55+
// Sample sentences to compare: several about exercise, several about growing
// plants, and a last one that mentions both radishes and "grows".
String[] inputs = [
    "Cycling is low impact and great for cardio",
    "Swimming is low impact and good for fitness",
    // Fixed typo: was "Palates", which changes the meaning of the sentence being embedded.
    "Pilates is good for fitness and flexibility",
    "Weights are good for strength and fitness",
    "Orchids can be tricky to grow",
    "Sunflowers are fun to grow",
    "Radishes are easy to grow",
    "The taste of radishes grows on you after a while",
]
var k = inputs.size()

var embeddings = predict(inputs)

// Pairwise similarity matrix: z[i][j] is the dot product of the embeddings
// of sentence i and sentence j.
def z = new double[k][k]
for (i in 0..<k) {
    println "Embedding for: ${inputs[i]}\n${Arrays.toString(embeddings[i])}"
    for (j in 0..<k) {
        z[i][j] = dot(embeddings[i], embeddings[j])
    }
}

// Show the matrix as a heatmap labelled with the sentences on both axes.
new Heatmap(inputs, inputs, z, Palette.heat(20).reverse()).canvas().with {
    title = 'Semantic textual similarity'
    setAxisLabels('', '')
    window()
}

0 commit comments

Comments
 (0)