Skip to content

Commit 3a78e04

Browse files
committed
fix breaking change
1 parent ffdbddd commit 3a78e04

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

api/src/main/java/com/launchableinc/openai/utils/TikTokensUtil.java

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import com.knuddels.jtokkit.api.Encoding;
55
import com.knuddels.jtokkit.api.EncodingRegistry;
66
import com.knuddels.jtokkit.api.EncodingType;
7+
import com.knuddels.jtokkit.api.IntArrayList;
78
import com.knuddels.jtokkit.api.ModelType;
89
import com.launchableinc.openai.completion.chat.ChatMessage;
910
import lombok.AllArgsConstructor;
@@ -46,7 +47,7 @@ public class TikTokensUtil {
4647
* @return Encoding array
4748
*/
4849
public static List<Integer> encode(Encoding enc, String text) {
49-
return isBlank(text) ? new ArrayList<>() : enc.encode(text);
50+
return isBlank(text) ? new ArrayList<>() : enc.encode(text).boxed();
5051
}
5152

5253
/**
@@ -69,7 +70,7 @@ public static int tokens(Encoding enc, String text) {
6970
* @return Text information corresponding to the encoding array.
7071
*/
7172
public static String decode(Encoding enc, List<Integer> encoded) {
72-
return enc.decode(encoded);
73+
return enc.decode(toIntArrayList(encoded));
7374
}
7475

7576
/**
@@ -94,7 +95,7 @@ public static List<Integer> encode(EncodingType encodingType, String text) {
9495
return new ArrayList<>();
9596
}
9697
Encoding enc = getEncoding(encodingType);
97-
List<Integer> encoded = enc.encode(text);
98+
List<Integer> encoded = enc.encode(text).boxed();
9899
return encoded;
99100
}
100101

@@ -119,7 +120,7 @@ public static int tokens(EncodingType encodingType, String text) {
119120
*/
120121
public static String decode(EncodingType encodingType, List<Integer> encoded) {
121122
Encoding enc = getEncoding(encodingType);
122-
return enc.decode(encoded);
123+
return enc.decode(toIntArrayList(encoded));
123124
}
124125

125126

@@ -147,7 +148,7 @@ public static List<Integer> encode(String modelName, String text) {
147148
if (Objects.isNull(enc)) {
148149
return new ArrayList<>();
149150
}
150-
List<Integer> encoded = enc.encode(text);
151+
List<Integer> encoded = enc.encode(text).boxed();
151152
return encoded;
152153
}
153154

@@ -209,7 +210,16 @@ public static int tokens(String modelName, List<ChatMessage> messages) {
209210
*/
210211
public static String decode(String modelName, List<Integer> encoded) {
211212
Encoding enc = getEncoding(modelName);
212-
return enc.decode(encoded);
213+
return enc.decode(toIntArrayList(encoded));
214+
}
215+
216+
private static IntArrayList toIntArrayList(List<Integer> encoded) {
217+
IntArrayList intArrayList = new IntArrayList(encoded.size());
218+
for (Integer e : encoded) {
219+
intArrayList.add(e);
220+
}
221+
222+
return intArrayList;
213223
}
214224

215225

0 commit comments

Comments
 (0)