Commit bf083f5

Author/committer: Feng, Xixuan (Aaron)

LDA: fix odd/even voc_size & topic_num, memory alignment
Pivotal Tracker: #100046648

Changes:
- add tests in install-check to cover all (odd/even) x (odd/even) combinations of voc_size and topic_num
- fix the padding for the LDA model, which stores an integer matrix in bigint[]: always allocate ((voc_size * (topic_num + 1) + 1) / 2) bigint elements
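A worked example of the allocation arithmetic described above; this is a sketch only, not code from this commit, the helper name model64_len is illustrative, and it assumes 4-byte int32_t and 8-byte int64_t:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// sketch only: the padded bigint[] length for the flattened
// voc_size x (topic_num + 1) int32 model
static std::size_t model64_len(int voc_size, int topic_num) {
    // equivalent to (voc_size * (topic_num + 1) + 1) / 2: an odd number of
    // int32 slots is rounded up to a whole number of bigints
    return static_cast<std::size_t>(voc_size * (topic_num + 1) + 1)
           * sizeof(int32_t) / sizeof(int64_t);
}

int main() {
    std::printf("%zu\n", model64_len(21, 6));  // 21 * 7 = 147 int32 slots -> 74 bigints
    std::printf("%zu\n", model64_len(20, 6));  // 20 * 7 = 140 int32 slots -> 70 bigints
    return 0;
}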
1 parent 5b57779 commit bf083f5

3 files changed: 91 additions, 42 deletions

src/modules/lda/lda.cpp

Lines changed: 44 additions & 33 deletions

@@ -11,6 +11,7 @@
 #include <dbconnector/dbconnector.hpp>
 #include <math.h>
 #include <iostream>
+#include <sstream>
 #include <algorithm>
 #include <functional>
 #include <numeric>
@@ -183,6 +184,7 @@ AnyType lda_gibbs_sample::run(AnyType & args)
     int32_t voc_size = args[6].getAs<int32_t>();
     int32_t topic_num = args[7].getAs<int32_t>();
     int32_t iter_num = args[8].getAs<int32_t>();
+    size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

     if(alpha <= 0)
         throw std::invalid_argument("invalid argument - alpha");
@@ -221,13 +223,14 @@ AnyType lda_gibbs_sample::run(AnyType & args)

     if (!args.getUserFuncContext()) {
         ArrayHandle<int64_t> model64 = args[3].getAs<ArrayHandle<int64_t> >();
-        if (model64.size() * sizeof(int64_t) / sizeof(int32_t)
-                != (size_t)(voc_size * (topic_num + 2))) {
-            throw std::invalid_argument(
-                "invalid dimension - model.size() != voc_size * (topic_num + 2)");
+        if (model64.size() != model64_size) {
+            std::stringstream ss;
+            ss << "invalid dimension: model64.size() = " << model64.size();
+            throw std::invalid_argument(ss.str());
         }
-        if(__min(model64) < 0)
+        if (__min(model64) < 0) {
             throw std::invalid_argument("invalid topic counts in model");
+        }

         int32_t *context =
             static_cast<int32_t *>(
@@ -239,23 +242,23 @@ AnyType lda_gibbs_sample::run(AnyType & args)
         int32_t *model = context;

         int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
-                context + model64.size() * sizeof(int64_t) / sizeof(int32_t));
+                context + model64_size * sizeof(int64_t) / sizeof(int32_t));
         for (int i = 0; i < voc_size; i ++) {
             for (int j = 0; j < topic_num; j ++) {
-                running_topic_counts[j] += model[i * (topic_num + 2) + j];
+                running_topic_counts[j] += model[i * (topic_num + 1) + j];
             }
         }

         args.setUserFuncContext(context);
     }

     int32_t *context = static_cast<int32_t *>(args.getUserFuncContext());
-    if(context == NULL) {
+    if (context == NULL) {
         throw std::runtime_error("args.mSysInfo->user_fctx is null");
     }
     int32_t *model = context;
     int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
-        context + voc_size * (topic_num + 2));
+        context + model64_size * sizeof(int64_t) / sizeof(int32_t));

     int32_t unique_word_count = static_cast<int32_t>(words.size());
     for(int32_t it = 0; it < iter_num; it++){
@@ -266,20 +269,20 @@ AnyType lda_gibbs_sample::run(AnyType & args)
                 int32_t topic = doc_topic[word_index];
                 int32_t retopic = __lda_gibbs_sample(
                     topic_num, topic, doc_topic.ptr(),
-                    model + wordid * (topic_num + 2),
+                    model + wordid * (topic_num + 1),
                     running_topic_counts, alpha, beta);
                 doc_topic[word_index] = retopic;
                 doc_topic[topic]--;
                 doc_topic[retopic]++;

                 if(iter_num == 1) {
-                    if (model[wordid * (topic_num + 2) + retopic] <= 2e9) {
+                    if (model[wordid * (topic_num + 1) + retopic] <= 2e9) {
                         running_topic_counts[topic] --;
                         running_topic_counts[retopic] ++;
-                        model[wordid * (topic_num + 2) + topic]--;
-                        model[wordid * (topic_num + 2) + retopic]++;
+                        model[wordid * (topic_num + 1) + topic]--;
+                        model[wordid * (topic_num + 1) + retopic]++;
                     } else {
-                        model[wordid * (topic_num + 2) + topic_num] = 1;
+                        model[wordid * (topic_num + 1) + topic_num] = 1;
                     }
                 }
                 word_index++;
@@ -371,12 +374,19 @@ AnyType lda_count_topic_sfunc::run(AnyType & args)

     MutableArrayHandle<int64_t> state(NULL);
     int32_t *model;
-    if(args[0].isNull()){
-        int dims[1] = {voc_size * (topic_num + 2) * sizeof(int32_t) / sizeof(int64_t)};
+    if(args[0].isNull()) {
+        // to store a voc_size x (topic_num+1) integer matrix in
+        // bigint[] (the +1 is for a flag of ceiling the count),
+        // we need padding if the size is odd.
+        // 1. when voc_size * (topic_num + 1) is (2n+1), gives (n+1)
+        // 2. when voc_size * (topic_num + 1) is (2n), gives (n)
+        int dims[1] = {(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t)};
         int lbs[1] = {1};
         state = madlib_construct_md_array(
             NULL, NULL, 1, dims, lbs, INT8TI.oid, INT8TI.len, INT8TI.byval,
             INT8TI.align);
+        // the reason we use bigint[] because integer[] has limit on number of
+        // elements and thus cannot be larger than 500MB
         model = reinterpret_cast<int32_t *>(state.ptr());
     } else {
         state = args[0].getAs<MutableArrayHandle<int64_t> >();
@@ -389,10 +399,10 @@ AnyType lda_count_topic_sfunc::run(AnyType & args)
         int32_t wordid = words[i];
         for(int32_t j = 0; j < counts[i]; j++){
             int32_t topic = topic_assignment[word_index];
-            if (model[wordid * (topic_num + 2) + topic] <= 2e9) {
-                model[wordid * (topic_num + 2) + topic]++;
+            if (model[wordid * (topic_num + 1) + topic] <= 2e9) {
+                model[wordid * (topic_num + 1) + topic]++;
             } else {
-                model[wordid * (topic_num + 2) + topic_num] = 1;
+                model[wordid * (topic_num + 1) + topic_num] = 1;
             }
             word_index++;
         }
@@ -496,7 +506,7 @@ AnyType lda_unnest_transpose::SRF_next(void *user_fctx, bool *is_last_call)
             NULL, ctx->dim, INT4TI.oid, INT4TI.len, INT4TI.byval,
             INT4TI.align));
     for (int i = 0; i < ctx->dim; i ++) {
-        outarray[i] = ctx->inarray[(ctx->maxcall + 2) * i + ctx->curcall];
+        outarray[i] = ctx->inarray[(ctx->maxcall + 1) * i + ctx->curcall];
     }

     ctx->curcall++;
@@ -535,7 +545,7 @@ AnyType lda_unnest::SRF_next(void *user_fctx, bool *is_last_call)
             NULL, ctx->dim, INT4TI.oid, INT4TI.len, INT4TI.byval,
             INT4TI.align));
     for (int i = 0; i < ctx->dim; i ++) {
-        outarray[i] = ctx->inarray[ctx->curcall * (ctx->dim + 2) + i];
+        outarray[i] = ctx->inarray[ctx->curcall * (ctx->dim + 1) + i];
     }

     ctx->curcall++;
@@ -569,6 +579,7 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
     double beta = args[6].getAs<double>();
     int32_t voc_size = args[7].getAs<int32_t>();
     int32_t topic_num = args[8].getAs<int32_t>();
+    size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

     if(alpha <= 0)
         throw std::invalid_argument("invalid argument - alpha");
@@ -598,15 +609,17 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
         throw std::invalid_argument("invalid values in doc_topic_counts");

     MutableArrayHandle<int64_t> state(NULL);
-    if(args[0].isNull()){
+    if (args[0].isNull()) {
         ArrayHandle<int64_t> model64 = args[4].getAs<ArrayHandle<int64_t> >();

-        if(model64.size() * sizeof(int64_t) / sizeof(int32_t)
-                != (size_t)(voc_size * (topic_num + 2)))
-            throw std::invalid_argument(
-                "invalid dimension - model.size() != voc_size * (topic_num + 2)");
-        if(__min(model64) < 0)
+        if (model64.size() != model64_size) {
+            std::stringstream ss;
+            ss << "invalid dimension: model64.size() = " << model64.size();
+            throw std::invalid_argument(ss.str());
+        }
+        if(__min(model64) < 0) {
             throw std::invalid_argument("invalid topic counts in model");
+        }

         state = madlib_construct_array(NULL,
             static_cast<int>(model64.size())
@@ -619,20 +632,18 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){

         memcpy(state.ptr(), model64.ptr(), model64.size() * sizeof(int64_t));
         int32_t *model = reinterpret_cast<int32_t *>(state.ptr());
-
         int64_t *total_topic_counts = reinterpret_cast<int64_t *>(state.ptr() + model64.size());
         for (int i = 0; i < voc_size; i ++) {
             for (int j = 0; j < topic_num; j ++) {
-                total_topic_counts[j] += model[i * (topic_num + 2) + j];
+                total_topic_counts[j] += model[i * (topic_num + 1) + j];
             }
         }
     }else{
         state = args[0].getAs<MutableArrayHandle<int64_t> >();
     }

     int32_t *model = reinterpret_cast<int32_t *>(state.ptr());
-    int64_t *total_topic_counts = reinterpret_cast<int64_t *>(
-        state.ptr() + voc_size * (topic_num + 2) * sizeof(int32_t) / sizeof(int64_t));
+    int64_t *total_topic_counts = reinterpret_cast<int64_t *>(state.ptr() + model64_size);
     double *perp = reinterpret_cast<double *>(state.ptr() + state.size() - 1);

     int32_t n_d = 0;
@@ -647,7 +658,7 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
         double sum_p = 0.0;
         for(int32_t z = 0; z < topic_num; z++){
             int32_t n_dz = doc_topic_counts[z];
-            int32_t n_wz = model[w * (topic_num + 2) + z];
+            int32_t n_wz = model[w * (topic_num + 1) + z];
             int64_t n_z = total_topic_counts[z];
             sum_p += (static_cast<double>(n_wz) + beta) * (n_dz + alpha)
                 / (static_cast<double>(n_z) + voc_size * beta);
@@ -698,7 +709,7 @@ lda_check_count_ceiling::run(AnyType &args) {
     int count = 0;
     const int32_t *model = reinterpret_cast<const int32_t *>(model64.ptr());
     for (int wordid = 0; wordid < voc_size; wordid ++) {
-        int flag = model[wordid * (topic_num + 2) + topic_num];
+        int flag = model[wordid * (topic_num + 1) + topic_num];
        if (flag != 0) {
            example_words_hit_ceiling[count ++] = wordid;
        }
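To make the indexing in the hunks above concrete: after this change each word's row holds topic_num counts plus one flag column, i.e. topic_num + 1 int32 slots, and the slot at offset topic_num records that a count hit the 2e9 ceiling. A minimal sketch of that layout, with illustrative helper names (topic_count, ceiling_flag) that are not part of the MADlib code:

#include <cstdint>

// illustrative only: `model` is the int32 view of the bigint[] state,
// laid out as voc_size rows of topic_num + 1 slots each
inline int32_t &topic_count(int32_t *model, int wordid, int topic_num, int topic) {
    return model[wordid * (topic_num + 1) + topic];       // counts at offsets 0 .. topic_num - 1
}

inline int32_t &ceiling_flag(int32_t *model, int wordid, int topic_num) {
    return model[wordid * (topic_num + 1) + topic_num];   // flag column at offset topic_num
}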

src/ports/postgres/modules/lda/lda.py_in

Lines changed: 1 addition & 1 deletion

@@ -1129,5 +1129,5 @@ def _validate_model_table(model_table):
             'alpha in %s should be a positive real number' % (model_table))
     _assert(rv[0]['beta'] > 0,
             'beta in %s should be a positive real number' % (model_table))
-    _assert(rv[0]['model_size'] * 2 == (rv[0]['voc_size']) * (rv[0]['topic_num'] + 2),
+    _assert(rv[0]['model_size'] == ((rv[0]['voc_size']) * (rv[0]['topic_num'] + 1) + 1) / 2,
             "model_size mismatches with voc_size and topic_num in %s" % (model_table))

src/ports/postgres/modules/lda/test/lda.sql_in

Lines changed: 46 additions & 8 deletions

@@ -11,8 +11,7 @@ m4_include(`SQLCommon.m4')
 ---------------------------------------------------------------------------
 -- Build vocabulary:
 ---------------------------------------------------------------------------
-CREATE TABLE lda_vocab(wordid INT4, word TEXT)
-m4_ifdef(`__GREENPLUM__',`DISTRIBUTED BY (wordid)');
+CREATE TABLE lda_vocab(wordid INT4, word TEXT);

 INSERT INTO lda_vocab VALUES
 (0, 'code'), (1, 'data'), (2, 'graph'), (3, 'image'), (4, 'input'), (5,
@@ -28,8 +27,7 @@ CREATE TABLE lda_training
     docid INT4,
     wordid INT4,
     count INT4
-)
-m4_ifdef(`__GREENPLUM__',`DISTRIBUTED BY (docid)');
+);

 INSERT INTO lda_training VALUES
 (0, 0, 2),(0, 3, 2),(0, 5, 1),(0, 7, 1),(0, 8, 1),(0, 9, 1),(0, 11, 1),(0, 13,
@@ -45,6 +43,27 @@ INSERT INTO lda_training VALUES
 (9, 0, 1),(9, 1, 1),(9, 4, 1),(9, 9, 2),(9, 12, 2),(9, 15, 1),(9, 18, 1),(9,
 19, 1);

+CREATE TABLE lda_training_odd_voc_size
+(
+    docid INT4,
+    wordid INT4,
+    count INT4
+);
+
+INSERT INTO lda_training_odd_voc_size VALUES
+(0, 0, 2),(0, 3, 2),(0, 5, 1),(0, 7, 1),(0, 8, 1),(0, 9, 1),(0, 11, 1),(0, 13,
+1), (1, 0, 1),(1, 3, 1),(1, 4, 1),(1, 5, 1),(1, 6, 1),(1, 7, 1),(1, 10, 1),(1,
+14, 1),(1, 17, 1),(1, 18, 1), (2, 4, 2),(2, 5, 1),(2, 6, 2),(2, 12, 1),(2, 13,
+1),(2, 15, 1),(2, 18, 2), (3, 0, 1),(3, 1, 2),(3, 12, 3),(3, 16, 1),(3, 17,
+2),(3, 19, 1), (4, 1, 1),(4, 2, 1),(4, 3, 1),(4, 5, 1),(4, 6, 1),(4, 10, 1),(4,
+11, 1),(4, 14, 1),(4, 18, 1),(4, 19, 1), (5, 0, 1),(5, 2, 1),(5, 5, 1),(5, 7,
+1),(5, 10, 1),(5, 12, 1),(5, 16, 1),(5, 18, 1),(5, 19, 2), (6, 1, 1),(6, 3,
+1),(6, 12, 2),(6, 13, 1),(6, 14, 2),(6, 15, 1),(6, 16, 1),(6, 17, 1), (7, 0,
+1),(7, 2, 1),(7, 4, 1),(7, 5, 1),(7, 7, 2),(7, 8, 1),(7, 11, 1),(7, 14, 1),(7,
+16, 1), (8, 2, 1),(8, 4, 4),(8, 6, 2),(8, 11, 1),(8, 15, 1),(8, 18, 1),
+(9, 0, 1),(9, 1, 1),(9, 4, 1),(9, 9, 2),(9, 12, 2),(9, 15, 1),(9, 18, 1),(9,
+19, 1),(9, 20, 1);
+

 CREATE TABLE lda_testing
 (
@@ -104,20 +123,20 @@ SELECT __lda_util_norm_with_smoothing(array[1, 4, 2, 3], 0.1);
 SELECT
 assert(
     __lda_check_count_ceiling(
-        array[1, 0, 4, 0]::bigint[],
+        array[0, 0, 0, 0]::bigint[],
         2,
         2)
     IS NULL,
-    '__lda_check_count_ceiling should return NULL for [1, 0, 4, 0]');
+    '__lda_check_count_ceiling should return NULL for [0, 0, 0, 0]');
 -- length: 1
 SELECT
 assert(
     __lda_check_count_ceiling(
-        array[1, -1, 4, 0]::bigint[],
+        array[-1, -1, -1, -1]::bigint[],
         2,
         2)
     IS NOT NULL,
-    '__lda_check_count_ceiling should not return NULL for [1, -1, 4, 0]');
+    '__lda_check_count_ceiling should not return NULL for [-1, -1, -1, -1]');

 SELECT lda_get_topic_desc(
     'lda_model',
@@ -143,3 +162,22 @@ FROM __lda_util_norm_dataset('lda_testing', 'norm_lda_vocab',
 SELECT *
 FROM __lda_util_conorm_data('lda_testing', 'lda_vocab',
     'norm_lda_data_2', 'norm_lda_vocab_2');
+
+-- both voc_size and topic_num are odd or even
+SELECT lda_train(
+    'lda_training_odd_voc_size',
+    'lda_model_odd_voc_size_even_topic_num',
+    'lda_output_odd_voc_size_even_topic_num',
+    21, 6, 2, 3, 0.01);
+
+SELECT lda_train(
+    'lda_training_odd_voc_size',
+    'lda_model_odd_voc_size_odd_topic_num',
+    'lda_output_odd_voc_size_odd_topic_num',
+    21, 5, 2, 3, 0.01);
+
+SELECT lda_train(
+    'lda_training',
+    'lda_model_even_voc_size_even_topic_num',
+    'lda_output_even_voc_size_even_topic_num',
+    20, 6, 2, 3, 0.01);
