11/*
2- * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
2+ * Copyright (c) 2025, 2026, Oracle and/or its affiliates. All rights reserved.
33 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44 *
55 * This code is free software; you can redistribute it and/or modify it
@@ -43,32 +43,36 @@ public final class LlamaModel {
4343 VOCAB_SIZE = 128256 ,
4444 HEAD_SIZE = 64 ,
4545 HIDEN_SIZE = 2048 ,
46+ KV_HIDEN_SIZE = 512 ,
4647 CONTEXT_SIZE = 131072 ,
4748 INTERMEDIATE_SIZE = 8192 ,
4849 ATTN_WEIGHTS_SIZE = 3072 ;
4950 public static final float EPSILON = 1.0E-5f ,
5051 SCALE = 0.125f ;
5152
5253 public final Tensor <Long > flat1 , scalar1 ;
53- public final Tensor <Float > tokensWeights , initWeight , cosCache , sinCache , headScales ;
54+ public final Tensor <Float > tokensWeights , initWeight , cosCache , sinCache ;
5455 public final Tensor <Float >[] postAttentionWeights = new Tensor [LAYERS ],
5556 inputWeights = new Tensor [LAYERS ],
56- attnQkvScales = new Tensor [LAYERS ],
57+ attnQScales = new Tensor [LAYERS ],
58+ attnKScales = new Tensor [LAYERS ],
59+ attnVScales = new Tensor [LAYERS ],
5760 attnOScales = new Tensor [LAYERS ],
5861 mlpGateScales = new Tensor [LAYERS ],
5962 mlpUpScales = new Tensor [LAYERS ],
6063 mlpDownScales = new Tensor [LAYERS ];
61- public final Tensor <Byte >[] attnQkvWeight = new Tensor [LAYERS ],
64+ public final Tensor <Byte >[] attnQWeight = new Tensor [LAYERS ],
65+ attnKWeight = new Tensor [LAYERS ],
66+ attnVWeight = new Tensor [LAYERS ],
6267 attnOWeight = new Tensor [LAYERS ],
6368 mlpGateWeight = new Tensor [LAYERS ],
6469 mlpUpWeight = new Tensor [LAYERS ],
6570 mlpDownWeight = new Tensor [LAYERS ];
66- public final Tensor <Byte > headWeight ;
6771
6872 public LlamaModel (Arena arena ) throws IOException {
6973 flat1 = Tensor .ofFlat (arena , 1l );
7074 scalar1 = Tensor .ofScalar (arena , 1l );
71- var modelData = new TensorDataStream (arena , LlamaModel .class .getResource ("model.onnx.data " ).getPath ());
75+ var modelData = new TensorDataStream (arena , LlamaModel .class .getResource ("model_q4.onnx_data " ).getPath ());
7276 tokensWeights = modelData .nextTensor (FLOAT , VOCAB_SIZE , HIDEN_SIZE );
7377 initWeight = modelData .nextTensor (FLOAT , HIDEN_SIZE );
7478 cosCache = modelData .nextTensor (FLOAT , CONTEXT_SIZE , HEAD_SIZE / 2 );
@@ -78,19 +82,21 @@ public LlamaModel(Arena arena) throws IOException {
7882 inputWeights [i ] = modelData .nextTensor (FLOAT , HIDEN_SIZE );
7983 }
8084 for (int i = 0 ; i < LAYERS ; i ++) {
81- attnQkvWeight [i ] = modelData .nextTensor (UINT8 , ATTN_WEIGHTS_SIZE , HEAD_SIZE , 16 );
82- attnQkvScales [i ] = modelData .nextTensor (FLOAT , ATTN_WEIGHTS_SIZE * HEAD_SIZE );
85+ attnQWeight [i ] = modelData .nextTensor (UINT8 , HIDEN_SIZE , HEAD_SIZE , 16 );
86+ attnQScales [i ] = modelData .nextTensor (FLOAT , HIDEN_SIZE , HEAD_SIZE );
87+ attnKWeight [i ] = modelData .nextTensor (UINT8 , KV_HIDEN_SIZE , HEAD_SIZE , 16 );
88+ attnKScales [i ] = modelData .nextTensor (FLOAT , KV_HIDEN_SIZE , HEAD_SIZE );
89+ attnVWeight [i ] = modelData .nextTensor (UINT8 , KV_HIDEN_SIZE , HEAD_SIZE , 16 );
90+ attnVScales [i ] = modelData .nextTensor (FLOAT , KV_HIDEN_SIZE , HEAD_SIZE );
8391 attnOWeight [i ] = modelData .nextTensor (UINT8 , HIDEN_SIZE , HEAD_SIZE , 16 );
84- attnOScales [i ] = modelData .nextTensor (FLOAT , HIDEN_SIZE * HEAD_SIZE );
92+ attnOScales [i ] = modelData .nextTensor (FLOAT , HIDEN_SIZE , HEAD_SIZE );
8593 mlpGateWeight [i ] = modelData .nextTensor (UINT8 , INTERMEDIATE_SIZE , HEAD_SIZE , 16 );
86- mlpGateScales [i ] = modelData .nextTensor (FLOAT , INTERMEDIATE_SIZE * HEAD_SIZE );
94+ mlpGateScales [i ] = modelData .nextTensor (FLOAT , INTERMEDIATE_SIZE , HEAD_SIZE );
8795 mlpUpWeight [i ] = modelData .nextTensor (UINT8 , INTERMEDIATE_SIZE , HEAD_SIZE , 16 );
88- mlpUpScales [i ] = modelData .nextTensor (FLOAT , INTERMEDIATE_SIZE * HEAD_SIZE );
96+ mlpUpScales [i ] = modelData .nextTensor (FLOAT , INTERMEDIATE_SIZE , HEAD_SIZE );
8997 mlpDownWeight [i ] = modelData .nextTensor (UINT8 , HIDEN_SIZE , 256 , 16 );
90- mlpDownScales [i ] = modelData .nextTensor (FLOAT , INTERMEDIATE_SIZE * HEAD_SIZE );
98+ mlpDownScales [i ] = modelData .nextTensor (FLOAT , HIDEN_SIZE , 256 );
9199 }
92- headWeight = modelData .nextTensor (UINT8 , VOCAB_SIZE , HEAD_SIZE , 16 );
93- headScales = modelData .nextTensor (FLOAT , VOCAB_SIZE * HEAD_SIZE );
94100 }
95101
96102 public record ForwardResponse (Tensor <Float > logits ,
@@ -110,12 +116,15 @@ public ForwardResponse forward(Tensor<Long> inputIds, Tensor<Long> attentionMask
110116 Tensor <Float >[] presentValues = new Tensor [LAYERS ];
111117
112118 for (int i = 0 ; i < LAYERS ; i ++) {
113- GroupQueryAttention <Float > attn = GroupQueryAttention (
114- MatMulNBits (input ,
115- attnQkvWeight [i ],
116- attnQkvScales [i ], empty (), empty (), empty (), HIDEN_SIZE , ATTN_WEIGHTS_SIZE , of (ACCURACY_LEVEL ), BITS , BLOCK_SIZE ),
117- empty (),
118- empty (),
119+ GroupQueryAttention <Float > attn = GroupQueryAttention (MatMulNBits (input ,
120+ attnQWeight [i ],
121+ attnQScales [i ], empty (), empty (), empty (), HIDEN_SIZE , HIDEN_SIZE , of (ACCURACY_LEVEL ), BITS , BLOCK_SIZE ),
122+ of (MatMulNBits (input ,
123+ attnKWeight [i ],
124+ attnKScales [i ], empty (), empty (), empty (), HIDEN_SIZE , KV_HIDEN_SIZE , of (ACCURACY_LEVEL ), BITS , BLOCK_SIZE )),
125+ of (MatMulNBits (input ,
126+ attnVWeight [i ],
127+ attnVScales [i ], empty (), empty (), empty (), HIDEN_SIZE , KV_HIDEN_SIZE , of (ACCURACY_LEVEL ), BITS , BLOCK_SIZE )),
119128 of (pastKey [i ]),
120129 of (pastValue [i ]),
121130 amSL ,
@@ -150,9 +159,7 @@ mlpDownScales[i], empty(), empty(), empty(), INTERMEDIATE_SIZE, HIDEN_SIZE, of(A
150159 presentValues [i ] = attn .present_value ();
151160 }
152161
153- Tensor <Float > logits = MatMulNBits (input ,
154- headWeight ,
155- headScales , empty (), empty (), empty (), HIDEN_SIZE , VOCAB_SIZE , of (ACCURACY_LEVEL ), BITS , BLOCK_SIZE );
162+ Tensor <Float > logits = MatMul (input , Transpose (tokensWeights , of (new long [] {1L , 0L })));
156163
157164 return new ForwardResponse (logits , presentKeys , presentValues );
158165 }
0 commit comments