diff --git a/src/modules/lda/lda.cpp b/src/modules/lda/lda.cpp
index 0c9f26dee..c71b39b6f 100644
--- a/src/modules/lda/lda.cpp
+++ b/src/modules/lda/lda.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include <sstream>
 #include
 #include
 #include
@@ -183,6 +184,7 @@ AnyType lda_gibbs_sample::run(AnyType & args)
     int32_t voc_size = args[6].getAs<int32_t>();
     int32_t topic_num = args[7].getAs<int32_t>();
     int32_t iter_num = args[8].getAs<int32_t>();
+    size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

     if(alpha <= 0)
         throw std::invalid_argument("invalid argument - alpha");
@@ -221,13 +223,14 @@ AnyType lda_gibbs_sample::run(AnyType & args)

     if (!args.getUserFuncContext()) {
         ArrayHandle<int64_t> model64 = args[3].getAs<ArrayHandle<int64_t> >();
-        if (model64.size() * sizeof(int64_t) / sizeof(int32_t)
-            != (size_t)(voc_size * (topic_num + 2))) {
-            throw std::invalid_argument(
-                "invalid dimension - model.size() != voc_size * (topic_num + 2)");
+        if (model64.size() != model64_size) {
+            std::stringstream ss;
+            ss << "invalid dimension: model64.size() = " << model64.size();
+            throw std::invalid_argument(ss.str());
         }
-        if(__min(model64) < 0)
+        if (__min(model64) < 0) {
             throw std::invalid_argument("invalid topic counts in model");
+        }

         int32_t *context = static_cast<int32_t *>(
@@ -239,10 +242,10 @@ AnyType lda_gibbs_sample::run(AnyType & args)

         int32_t *model = context;
         int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
-            context + model64.size() * sizeof(int64_t) / sizeof(int32_t));
+            context + model64_size * sizeof(int64_t) / sizeof(int32_t));
         for (int i = 0; i < voc_size; i ++) {
             for (int j = 0; j < topic_num; j ++) {
-                running_topic_counts[j] += model[i * (topic_num + 2) + j];
+                running_topic_counts[j] += model[i * (topic_num + 1) + j];
             }
         }
@@ -250,12 +253,12 @@ AnyType lda_gibbs_sample::run(AnyType & args)
     }

     int32_t *context = static_cast<int32_t *>(args.getUserFuncContext());
-    if(context == NULL) {
+    if (context == NULL) {
         throw std::runtime_error("args.mSysInfo->user_fctx is null");
     }
     int32_t *model = context;
     int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
-        context + voc_size * (topic_num + 2));
+        context + model64_size * sizeof(int64_t) / sizeof(int32_t));

     int32_t unique_word_count = static_cast<int32_t>(words.size());
     for(int32_t it = 0; it < iter_num; it++){
@@ -266,20 +269,20 @@ AnyType lda_gibbs_sample::run(AnyType & args)
                 int32_t topic = doc_topic[word_index];
                 int32_t retopic = __lda_gibbs_sample(
                     topic_num, topic, doc_topic.ptr(),
-                    model + wordid * (topic_num + 2),
+                    model + wordid * (topic_num + 1),
                     running_topic_counts, alpha, beta);
                 doc_topic[word_index] = retopic;
                 doc_topic[topic]--;
                 doc_topic[retopic]++;

                 if(iter_num == 1) {
-                    if (model[wordid * (topic_num + 2) + retopic] <= 2e9) {
+                    if (model[wordid * (topic_num + 1) + retopic] <= 2e9) {
                         running_topic_counts[topic] --;
                         running_topic_counts[retopic] ++;
-                        model[wordid * (topic_num + 2) + topic]--;
-                        model[wordid * (topic_num + 2) + retopic]++;
+                        model[wordid * (topic_num + 1) + topic]--;
+                        model[wordid * (topic_num + 1) + retopic]++;
                     } else {
-                        model[wordid * (topic_num + 2) + topic_num] = 1;
+                        model[wordid * (topic_num + 1) + topic_num] = 1;
                     }
                 }
                 word_index++;
@@ -371,12 +374,19 @@ AnyType lda_count_topic_sfunc::run(AnyType & args)

     MutableArrayHandle<int64_t> state(NULL);
     int32_t *model;
-    if(args[0].isNull()){
-        int dims[1] = {voc_size * (topic_num + 2) * sizeof(int32_t) / sizeof(int64_t)};
+    if(args[0].isNull()) {
+        // To store a voc_size x (topic_num + 1) integer matrix in bigint[]
+        // (the extra column per word is a flag marking that a count hit the
+        // ceiling), we need one int32 of padding when the size is odd:
+        // 1. when voc_size * (topic_num + 1) is (2n + 1), this gives (n + 1)
+        // 2. when voc_size * (topic_num + 1) is (2n), this gives (n)
+        int dims[1] = {(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t)};
         int lbs[1] = {1};
         state = madlib_construct_md_array(
             NULL, NULL, 1, dims, lbs, INT8TI.oid, INT8TI.len, INT8TI.byval,
             INT8TI.align);
+        // bigint[] is used because integer[] has a limit on the number of
+        // elements and thus cannot be larger than 500MB
         model = reinterpret_cast<int32_t *>(state.ptr());
     } else {
         state = args[0].getAs<MutableArrayHandle<int64_t> >();
@@ -389,10 +399,10 @@ AnyType lda_count_topic_sfunc::run(AnyType & args)
         int32_t wordid = words[i];
         for(int32_t j = 0; j < counts[i]; j++){
             int32_t topic = topic_assignment[word_index];
-            if (model[wordid * (topic_num + 2) + topic] <= 2e9) {
-                model[wordid * (topic_num + 2) + topic]++;
+            if (model[wordid * (topic_num + 1) + topic] <= 2e9) {
+                model[wordid * (topic_num + 1) + topic]++;
             } else {
-                model[wordid * (topic_num + 2) + topic_num] = 1;
+                model[wordid * (topic_num + 1) + topic_num] = 1;
             }
             word_index++;
         }
@@ -496,7 +506,7 @@ AnyType lda_unnest_transpose::SRF_next(void *user_fctx, bool *is_last_call)
             NULL, ctx->dim, INT4TI.oid, INT4TI.len, INT4TI.byval,
             INT4TI.align));
     for (int i = 0; i < ctx->dim; i ++) {
-        outarray[i] = ctx->inarray[(ctx->maxcall + 2) * i + ctx->curcall];
+        outarray[i] = ctx->inarray[(ctx->maxcall + 1) * i + ctx->curcall];
     }

     ctx->curcall++;
@@ -535,7 +545,7 @@ AnyType lda_unnest::SRF_next(void *user_fctx, bool *is_last_call)
             NULL, ctx->dim, INT4TI.oid, INT4TI.len, INT4TI.byval,
             INT4TI.align));
     for (int i = 0; i < ctx->dim; i ++) {
-        outarray[i] = ctx->inarray[ctx->curcall * (ctx->dim + 2) + i];
+        outarray[i] = ctx->inarray[ctx->curcall * (ctx->dim + 1) + i];
     }

     ctx->curcall++;
@@ -569,6 +579,7 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
     double beta = args[6].getAs<double>();
     int32_t voc_size = args[7].getAs<int32_t>();
     int32_t topic_num = args[8].getAs<int32_t>();
+    size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

     if(alpha <= 0)
         throw std::invalid_argument("invalid argument - alpha");
@@ -598,15 +609,17 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
         throw std::invalid_argument("invalid values in doc_topic_counts");

     MutableArrayHandle<int64_t> state(NULL);
-    if(args[0].isNull()){
+    if (args[0].isNull()) {
         ArrayHandle<int64_t> model64 = args[4].getAs<ArrayHandle<int64_t> >();
-        if(model64.size() * sizeof(int64_t) / sizeof(int32_t)
-            != (size_t)(voc_size * (topic_num + 2)))
-            throw std::invalid_argument(
-                "invalid dimension - model.size() != voc_size * (topic_num + 2)");
-        if(__min(model64) < 0)
+        if (model64.size() != model64_size) {
+            std::stringstream ss;
+            ss << "invalid dimension: model64.size() = " << model64.size();
+            throw std::invalid_argument(ss.str());
+        }
+        if(__min(model64) < 0) {
             throw std::invalid_argument("invalid topic counts in model");
+        }

         state = madlib_construct_array(NULL,
                                        static_cast<int>(model64.size())
@@ -619,11 +632,10 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
         memcpy(state.ptr(), model64.ptr(), model64.size() * sizeof(int64_t));

         int32_t *model = reinterpret_cast<int32_t *>(state.ptr());
-
         int64_t *total_topic_counts = reinterpret_cast<int64_t *>(state.ptr() + model64.size());
         for (int i = 0; i < voc_size; i ++) {
             for (int j = 0; j < topic_num; j ++) {
-                total_topic_counts[j] += model[i * (topic_num + 2) + j];
+                total_topic_counts[j] += model[i * (topic_num + 1) + j];
             }
         }
     }else{
@@ -631,8 +643,7 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
     }

     int32_t *model = reinterpret_cast<int32_t *>(state.ptr());
-    int64_t *total_topic_counts = reinterpret_cast<int64_t *>(
-        state.ptr() + voc_size * (topic_num + 2) * sizeof(int32_t) / sizeof(int64_t));
+    int64_t *total_topic_counts = reinterpret_cast<int64_t *>(state.ptr() + model64_size);
     double *perp = reinterpret_cast<double *>(state.ptr() + state.size() - 1);

     int32_t n_d = 0;
@@ -647,7 +658,7 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
         double sum_p = 0.0;
         for(int32_t z = 0; z < topic_num; z++){
             int32_t n_dz = doc_topic_counts[z];
-            int32_t n_wz = model[w * (topic_num + 2) + z];
+            int32_t n_wz = model[w * (topic_num + 1) + z];
             int64_t n_z = total_topic_counts[z];
             sum_p += (static_cast<double>(n_wz) + beta) * (n_dz + alpha)
                 / (static_cast<double>(n_z) + voc_size * beta);
@@ -698,7 +709,7 @@ lda_check_count_ceiling::run(AnyType &args) {
     int count = 0;
     const int32_t *model = reinterpret_cast<const int32_t *>(model64.ptr());
     for (int wordid = 0; wordid < voc_size; wordid ++) {
-        int flag = model[wordid * (topic_num + 2) + topic_num];
+        int flag = model[wordid * (topic_num + 1) + topic_num];
         if (flag != 0) {
             example_words_hit_ceiling[count ++] = wordid;
         }
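
Taken together, the lda.cpp hunks shrink the per-word stride from (topic_num + 2) to (topic_num + 1) int32 entries (the topic counts plus the count-ceiling flag) and replace per-word padding with a single trailing int32 pad so the whole matrix fits a whole number of bigint slots. The following is a standalone sketch of that arithmetic and of the pointer offsets the hunks rely on; it is not part of the patch, and the sample voc_size/topic_num values are arbitrary.

    // Standalone sketch, not part of the patch; sample sizes are illustrative.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int voc_size   = 21;              // hypothetical vocabulary size
        const int topic_num  = 6;               // hypothetical topic count
        const int row_stride = topic_num + 1;   // topic counts plus the ceiling flag
        const int n_int32    = voc_size * row_stride;

        // bigint[] length: one int32 of padding when n_int32 is odd, i.e. the
        // same expression as dims[0] / model64_size in the hunks above
        const std::size_t model64_size =
            static_cast<std::size_t>(n_int32 + 1) * sizeof(std::int32_t) / sizeof(std::int64_t);

        std::printf("int32 slots = %d, bigint slots = %zu\n", n_int32, model64_size);

        // Indexing convention used throughout the patch (model is an int32_t*):
        //   count of topic z for word w : model[w * row_stride + z]
        //   ceiling flag for word w     : model[w * row_stride + topic_num]
        // The per-topic totals (int64) live model64_size bigint slots into the
        // buffer; on an int32_t pointer that offset is written as
        //   context + model64_size * sizeof(int64_t) / sizeof(int32_t)
        return 0;
    }

With these sample values the sketch prints 147 int32 slots and 74 bigint slots, matching dims[0] in lda_count_topic_sfunc and the model_size relation that the Python validation below checks.
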
diff --git a/src/ports/postgres/modules/lda/lda.py_in b/src/ports/postgres/modules/lda/lda.py_in
index a4ac200b8..b7560f133 100644
--- a/src/ports/postgres/modules/lda/lda.py_in
+++ b/src/ports/postgres/modules/lda/lda.py_in
@@ -1129,5 +1129,5 @@ def _validate_model_table(model_table):
             'alpha in %s should be a positive real number' % (model_table))
     _assert(rv[0]['beta'] > 0,
             'beta in %s should be a positive real number' % (model_table))
-    _assert(rv[0]['model_size'] * 2 == (rv[0]['voc_size']) * (rv[0]['topic_num'] + 2),
+    _assert(rv[0]['model_size'] == ((rv[0]['voc_size']) * (rv[0]['topic_num'] + 1) + 1) / 2,
            "model_size mismatches with voc_size and topic_num in %s" % (model_table))
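
The tightened assertion encodes the same relation as the C++ sizing: the stored model_size (the bigint[] length) should equal (voc_size * (topic_num + 1) + 1) / 2 under floor division. Note that the patch's "/ 2" floors only under Python 2 integer semantics; under Python 3 true division the exact equivalent would be "// 2". A minimal sketch of the relation with hypothetical values:

    // Minimal sketch of the relation the assertion encodes; values are hypothetical.
    #include <cassert>

    int main() {
        // hypothetical (voc_size, topic_num, expected model_size) rows
        const long rows[2][3] = {{21, 6, 74},   // 21 * 7 = 147 (odd)  -> (147 + 1) / 2 = 74
                                 {20, 6, 70}};  // 20 * 7 = 140 (even) -> (140 + 1) / 2 = 70
        for (const auto &r : rows) {
            assert(r[2] == (r[0] * (r[1] + 1) + 1) / 2);   // floor division
        }
        return 0;
    }
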
diff --git a/src/ports/postgres/modules/lda/test/lda.sql_in b/src/ports/postgres/modules/lda/test/lda.sql_in
index 121eadf5f..fd19c65ad 100644
--- a/src/ports/postgres/modules/lda/test/lda.sql_in
+++ b/src/ports/postgres/modules/lda/test/lda.sql_in
@@ -11,8 +11,7 @@ m4_include(`SQLCommon.m4')
 ---------------------------------------------------------------------------
 -- Build vocabulary:
 ---------------------------------------------------------------------------
-CREATE TABLE lda_vocab(wordid INT4, word TEXT)
-m4_ifdef(`__GREENPLUM__',`DISTRIBUTED BY (wordid)');
+CREATE TABLE lda_vocab(wordid INT4, word TEXT);

 INSERT INTO lda_vocab VALUES
 (0, 'code'), (1, 'data'), (2, 'graph'), (3, 'image'), (4, 'input'), (5,
@@ -28,8 +27,7 @@ CREATE TABLE lda_training
     docid INT4,
     wordid INT4,
     count INT4
-)
-m4_ifdef(`__GREENPLUM__',`DISTRIBUTED BY (docid)');
+);

 INSERT INTO lda_training VALUES
 (0, 0, 2),(0, 3, 2),(0, 5, 1),(0, 7, 1),(0, 8, 1),(0, 9, 1),(0, 11, 1),(0, 13,
@@ -45,6 +43,27 @@ INSERT INTO lda_training VALUES
 (9, 0, 1),(9, 1, 1),(9, 4, 1),(9, 9, 2),(9, 12, 2),(9, 15, 1),(9, 18, 1),(9,
 19, 1);

+CREATE TABLE lda_training_odd_voc_size
+(
+    docid INT4,
+    wordid INT4,
+    count INT4
+);
+
+INSERT INTO lda_training_odd_voc_size VALUES
+(0, 0, 2),(0, 3, 2),(0, 5, 1),(0, 7, 1),(0, 8, 1),(0, 9, 1),(0, 11, 1),(0, 13,
+1), (1, 0, 1),(1, 3, 1),(1, 4, 1),(1, 5, 1),(1, 6, 1),(1, 7, 1),(1, 10, 1),(1,
+14, 1),(1, 17, 1),(1, 18, 1), (2, 4, 2),(2, 5, 1),(2, 6, 2),(2, 12, 1),(2, 13,
+1),(2, 15, 1),(2, 18, 2), (3, 0, 1),(3, 1, 2),(3, 12, 3),(3, 16, 1),(3, 17,
+2),(3, 19, 1), (4, 1, 1),(4, 2, 1),(4, 3, 1),(4, 5, 1),(4, 6, 1),(4, 10, 1),(4,
+11, 1),(4, 14, 1),(4, 18, 1),(4, 19, 1), (5, 0, 1),(5, 2, 1),(5, 5, 1),(5, 7,
+1),(5, 10, 1),(5, 12, 1),(5, 16, 1),(5, 18, 1),(5, 19, 2), (6, 1, 1),(6, 3,
+1),(6, 12, 2),(6, 13, 1),(6, 14, 2),(6, 15, 1),(6, 16, 1),(6, 17, 1), (7, 0,
+1),(7, 2, 1),(7, 4, 1),(7, 5, 1),(7, 7, 2),(7, 8, 1),(7, 11, 1),(7, 14, 1),(7,
+16, 1), (8, 2, 1),(8, 4, 4),(8, 6, 2),(8, 11, 1),(8, 15, 1),(8, 18, 1),
+(9, 0, 1),(9, 1, 1),(9, 4, 1),(9, 9, 2),(9, 12, 2),(9, 15, 1),(9, 18, 1),(9,
+19, 1),(9, 20, 1);
+

 CREATE TABLE lda_testing
 (
@@ -104,20 +123,20 @@ SELECT __lda_util_norm_with_smoothing(array[1, 4, 2, 3], 0.1);

 SELECT assert(
     __lda_check_count_ceiling(
-        array[1, 0, 4, 0]::bigint[],
+        array[0, 0, 0, 0]::bigint[],
         2,
         2) IS NULL,
-    '__lda_check_count_ceiling should return NULL for [1, 0, 4, 0]');
+    '__lda_check_count_ceiling should return NULL for [0, 0, 0, 0]');

 -- length: 1
 SELECT assert(
     __lda_check_count_ceiling(
-        array[1, -1, 4, 0]::bigint[],
+        array[-1, -1, -1, -1]::bigint[],
         2,
         2) IS NOT NULL,
-    '__lda_check_count_ceiling should not return NULL for [1, -1, 4, 0]');
+    '__lda_check_count_ceiling should not return NULL for [-1, -1, -1, -1]');

 SELECT lda_get_topic_desc(
     'lda_model',
@@ -143,3 +162,22 @@ FROM __lda_util_norm_dataset('lda_testing', 'norm_lda_vocab',

 SELECT * FROM __lda_util_conorm_data('lda_testing', 'lda_vocab',
 'norm_lda_data_2', 'norm_lda_vocab_2');
+
+-- cover odd and even combinations of voc_size and topic_num
+SELECT lda_train(
+    'lda_training_odd_voc_size',
+    'lda_model_odd_voc_size_even_topic_num',
+    'lda_output_odd_voc_size_even_topic_num',
+    21, 6, 2, 3, 0.01);
+
+SELECT lda_train(
+    'lda_training_odd_voc_size',
+    'lda_model_odd_voc_size_odd_topic_num',
+    'lda_output_odd_voc_size_odd_topic_num',
+    21, 5, 2, 3, 0.01);
+
+SELECT lda_train(
+    'lda_training',
+    'lda_model_even_voc_size_even_topic_num',
+    'lda_output_even_voc_size_even_topic_num',
+    20, 6, 2, 3, 0.01);
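
The three new lda_train calls use (voc_size, topic_num) = (21, 6), (21, 5) and (20, 6), so voc_size * (topic_num + 1) comes out odd (147) in the first case and even (126, 140) in the other two, which exercises both branches of the padding rule above. As a quick back-of-the-envelope check of the model sizes these cases should produce (a sketch for the reader, not something the test file itself asserts):

    // Expected bigint[] model lengths for the new regression cases.
    #include <cstdio>

    int main() {
        // (voc_size, topic_num) pairs from the three new lda_train calls
        const int cases[3][2] = {{21, 6}, {21, 5}, {20, 6}};
        for (const auto &c : cases) {
            const int n_int32    = c[0] * (c[1] + 1);   // topic counts + flag per word
            const int model_size = (n_int32 + 1) / 2;   // bigint[] length after padding
            std::printf("voc_size=%d topic_num=%d -> %d int32 slots, %d bigint slots\n",
                        c[0], c[1], n_int32, model_size);
        }
        return 0;   // prints 147 -> 74, 126 -> 63, 140 -> 70
    }

These are the model_size values the tightened _validate_model_table assertion would accept for the corresponding model tables.
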