44import com .knuddels .jtokkit .api .Encoding ;
55import com .knuddels .jtokkit .api .EncodingRegistry ;
66import com .knuddels .jtokkit .api .EncodingType ;
7+ import com .knuddels .jtokkit .api .IntArrayList ;
78import com .knuddels .jtokkit .api .ModelType ;
89import com .launchableinc .openai .completion .chat .ChatMessage ;
910import lombok .AllArgsConstructor ;
@@ -46,7 +47,7 @@ public class TikTokensUtil {
4647 * @return Encoding array
4748 */
4849 public static List <Integer > encode (Encoding enc , String text ) {
49- return isBlank (text ) ? new ArrayList <>() : enc .encode (text );
50+ return isBlank (text ) ? new ArrayList <>() : enc .encode (text ). boxed () ;
5051 }
5152
5253 /**
@@ -69,7 +70,7 @@ public static int tokens(Encoding enc, String text) {
6970 * @return Text information corresponding to the encoding array.
7071 */
7172 public static String decode (Encoding enc , List <Integer > encoded ) {
72- return enc .decode (encoded );
73+ return enc .decode (toIntArrayList ( encoded ) );
7374 }
7475
7576 /**
@@ -94,7 +95,7 @@ public static List<Integer> encode(EncodingType encodingType, String text) {
9495 return new ArrayList <>();
9596 }
9697 Encoding enc = getEncoding (encodingType );
97- List <Integer > encoded = enc .encode (text );
98+ List <Integer > encoded = enc .encode (text ). boxed () ;
9899 return encoded ;
99100 }
100101
@@ -119,7 +120,7 @@ public static int tokens(EncodingType encodingType, String text) {
119120 */
120121 public static String decode (EncodingType encodingType , List <Integer > encoded ) {
121122 Encoding enc = getEncoding (encodingType );
122- return enc .decode (encoded );
123+ return enc .decode (toIntArrayList ( encoded ) );
123124 }
124125
125126
@@ -147,7 +148,7 @@ public static List<Integer> encode(String modelName, String text) {
147148 if (Objects .isNull (enc )) {
148149 return new ArrayList <>();
149150 }
150- List <Integer > encoded = enc .encode (text );
151+ List <Integer > encoded = enc .encode (text ). boxed () ;
151152 return encoded ;
152153 }
153154
@@ -209,7 +210,16 @@ public static int tokens(String modelName, List<ChatMessage> messages) {
209210 */
210211 public static String decode (String modelName , List <Integer > encoded ) {
211212 Encoding enc = getEncoding (modelName );
212- return enc .decode (encoded );
213+ return enc .decode (toIntArrayList (encoded ));
214+ }
215+
216+ private static IntArrayList toIntArrayList (List <Integer > encoded ) {
217+ IntArrayList intArrayList = new IntArrayList (encoded .size ());
218+ for (Integer e : encoded ) {
219+ intArrayList .add (e );
220+ }
221+
222+ return intArrayList ;
213223 }
214224
215225
0 commit comments