Skip to content

Commit

Permalink
LDA: fix odd/even voc_size & topic_num, memory alignment
Browse files Browse the repository at this point in the history
Pivotal Tracker: #100046648
Changes:
- add tests in install-check to cover all (odd/even) x (odd/even) cases
- fix the padding for the LDA model, which stores an integer matrix in bigint[]:
  always allocate ((voc_size * (topic_num + 1) + 1) / 2) bigint elements
  • Loading branch information
Feng, Xixuan (Aaron) committed Jul 29, 2015
1 parent 5b57779 commit bf083f5
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 42 deletions.
77 changes: 44 additions & 33 deletions src/modules/lda/lda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <dbconnector/dbconnector.hpp>
#include <math.h>
#include <iostream>
#include <sstream>
#include <algorithm>
#include <functional>
#include <numeric>
Expand Down Expand Up @@ -183,6 +184,7 @@ AnyType lda_gibbs_sample::run(AnyType & args)
int32_t voc_size = args[6].getAs<int32_t>();
int32_t topic_num = args[7].getAs<int32_t>();
int32_t iter_num = args[8].getAs<int32_t>();
size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

if(alpha <= 0)
throw std::invalid_argument("invalid argument - alpha");
Expand Down Expand Up @@ -221,13 +223,14 @@ AnyType lda_gibbs_sample::run(AnyType & args)

if (!args.getUserFuncContext()) {
ArrayHandle<int64_t> model64 = args[3].getAs<ArrayHandle<int64_t> >();
if (model64.size() * sizeof(int64_t) / sizeof(int32_t)
!= (size_t)(voc_size * (topic_num + 2))) {
throw std::invalid_argument(
"invalid dimension - model.size() != voc_size * (topic_num + 2)");
if (model64.size() != model64_size) {
std::stringstream ss;
ss << "invalid dimension: model64.size() = " << model64.size();
throw std::invalid_argument(ss.str());
}
if(__min(model64) < 0)
if (__min(model64) < 0) {
throw std::invalid_argument("invalid topic counts in model");
}

int32_t *context =
static_cast<int32_t *>(
Expand All @@ -239,23 +242,23 @@ AnyType lda_gibbs_sample::run(AnyType & args)
int32_t *model = context;

int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
context + model64.size() * sizeof(int64_t) / sizeof(int32_t));
context + model64_size * sizeof(int64_t) / sizeof(int32_t));
for (int i = 0; i < voc_size; i ++) {
for (int j = 0; j < topic_num; j ++) {
running_topic_counts[j] += model[i * (topic_num + 2) + j];
running_topic_counts[j] += model[i * (topic_num + 1) + j];
}
}

args.setUserFuncContext(context);
}

int32_t *context = static_cast<int32_t *>(args.getUserFuncContext());
if(context == NULL) {
if (context == NULL) {
throw std::runtime_error("args.mSysInfo->user_fctx is null");
}
int32_t *model = context;
int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
context + voc_size * (topic_num + 2));
context + model64_size * sizeof(int64_t) / sizeof(int32_t));

int32_t unique_word_count = static_cast<int32_t>(words.size());
for(int32_t it = 0; it < iter_num; it++){
Expand All @@ -266,20 +269,20 @@ AnyType lda_gibbs_sample::run(AnyType & args)
int32_t topic = doc_topic[word_index];
int32_t retopic = __lda_gibbs_sample(
topic_num, topic, doc_topic.ptr(),
model + wordid * (topic_num + 2),
model + wordid * (topic_num + 1),
running_topic_counts, alpha, beta);
doc_topic[word_index] = retopic;
doc_topic[topic]--;
doc_topic[retopic]++;

if(iter_num == 1) {
if (model[wordid * (topic_num + 2) + retopic] <= 2e9) {
if (model[wordid * (topic_num + 1) + retopic] <= 2e9) {
running_topic_counts[topic] --;
running_topic_counts[retopic] ++;
model[wordid * (topic_num + 2) + topic]--;
model[wordid * (topic_num + 2) + retopic]++;
model[wordid * (topic_num + 1) + topic]--;
model[wordid * (topic_num + 1) + retopic]++;
} else {
model[wordid * (topic_num + 2) + topic_num] = 1;
model[wordid * (topic_num + 1) + topic_num] = 1;
}
}
word_index++;
Expand Down Expand Up @@ -371,12 +374,19 @@ AnyType lda_count_topic_sfunc::run(AnyType & args)

MutableArrayHandle<int64_t> state(NULL);
int32_t *model;
if(args[0].isNull()){
int dims[1] = {voc_size * (topic_num + 2) * sizeof(int32_t) / sizeof(int64_t)};
if(args[0].isNull()) {
// To store a voc_size x (topic_num+1) integer matrix in
// bigint[] (the +1 column is a flag marking that the count
// hit its ceiling), we need padding when the element count is odd:
// 1. when voc_size * (topic_num + 1) is (2n+1), allocate (n+1) bigints
// 2. when voc_size * (topic_num + 1) is (2n),   allocate (n) bigints
int dims[1] = {(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t)};
int lbs[1] = {1};
state = madlib_construct_md_array(
NULL, NULL, 1, dims, lbs, INT8TI.oid, INT8TI.len, INT8TI.byval,
INT8TI.align);
// We use bigint[] rather than integer[] because an array has a limit
// on its number of elements and an integer[] thus cannot exceed 500MB.
model = reinterpret_cast<int32_t *>(state.ptr());
} else {
state = args[0].getAs<MutableArrayHandle<int64_t> >();
Expand All @@ -389,10 +399,10 @@ AnyType lda_count_topic_sfunc::run(AnyType & args)
int32_t wordid = words[i];
for(int32_t j = 0; j < counts[i]; j++){
int32_t topic = topic_assignment[word_index];
if (model[wordid * (topic_num + 2) + topic] <= 2e9) {
model[wordid * (topic_num + 2) + topic]++;
if (model[wordid * (topic_num + 1) + topic] <= 2e9) {
model[wordid * (topic_num + 1) + topic]++;
} else {
model[wordid * (topic_num + 2) + topic_num] = 1;
model[wordid * (topic_num + 1) + topic_num] = 1;
}
word_index++;
}
Expand Down Expand Up @@ -496,7 +506,7 @@ AnyType lda_unnest_transpose::SRF_next(void *user_fctx, bool *is_last_call)
NULL, ctx->dim, INT4TI.oid, INT4TI.len, INT4TI.byval,
INT4TI.align));
for (int i = 0; i < ctx->dim; i ++) {
outarray[i] = ctx->inarray[(ctx->maxcall + 2) * i + ctx->curcall];
outarray[i] = ctx->inarray[(ctx->maxcall + 1) * i + ctx->curcall];
}

ctx->curcall++;
Expand Down Expand Up @@ -535,7 +545,7 @@ AnyType lda_unnest::SRF_next(void *user_fctx, bool *is_last_call)
NULL, ctx->dim, INT4TI.oid, INT4TI.len, INT4TI.byval,
INT4TI.align));
for (int i = 0; i < ctx->dim; i ++) {
outarray[i] = ctx->inarray[ctx->curcall * (ctx->dim + 2) + i];
outarray[i] = ctx->inarray[ctx->curcall * (ctx->dim + 1) + i];
}

ctx->curcall++;
Expand Down Expand Up @@ -569,6 +579,7 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
double beta = args[6].getAs<double>();
int32_t voc_size = args[7].getAs<int32_t>();
int32_t topic_num = args[8].getAs<int32_t>();
size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

if(alpha <= 0)
throw std::invalid_argument("invalid argument - alpha");
Expand Down Expand Up @@ -598,15 +609,17 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
throw std::invalid_argument("invalid values in doc_topic_counts");

MutableArrayHandle<int64_t> state(NULL);
if(args[0].isNull()){
if (args[0].isNull()) {
ArrayHandle<int64_t> model64 = args[4].getAs<ArrayHandle<int64_t> >();

if(model64.size() * sizeof(int64_t) / sizeof(int32_t)
!= (size_t)(voc_size * (topic_num + 2)))
throw std::invalid_argument(
"invalid dimension - model.size() != voc_size * (topic_num + 2)");
if(__min(model64) < 0)
if (model64.size() != model64_size) {
std::stringstream ss;
ss << "invalid dimension: model64.size() = " << model64.size();
throw std::invalid_argument(ss.str());
}
if(__min(model64) < 0) {
throw std::invalid_argument("invalid topic counts in model");
}

state = madlib_construct_array(NULL,
static_cast<int>(model64.size())
Expand All @@ -619,20 +632,18 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){

memcpy(state.ptr(), model64.ptr(), model64.size() * sizeof(int64_t));
int32_t *model = reinterpret_cast<int32_t *>(state.ptr());

int64_t *total_topic_counts = reinterpret_cast<int64_t *>(state.ptr() + model64.size());
for (int i = 0; i < voc_size; i ++) {
for (int j = 0; j < topic_num; j ++) {
total_topic_counts[j] += model[i * (topic_num + 2) + j];
total_topic_counts[j] += model[i * (topic_num + 1) + j];
}
}
}else{
state = args[0].getAs<MutableArrayHandle<int64_t> >();
}

int32_t *model = reinterpret_cast<int32_t *>(state.ptr());
int64_t *total_topic_counts = reinterpret_cast<int64_t *>(
state.ptr() + voc_size * (topic_num + 2) * sizeof(int32_t) / sizeof(int64_t));
int64_t *total_topic_counts = reinterpret_cast<int64_t *>(state.ptr() + model64_size);
double *perp = reinterpret_cast<double *>(state.ptr() + state.size() - 1);

int32_t n_d = 0;
Expand All @@ -647,7 +658,7 @@ AnyType lda_perplexity_sfunc::run(AnyType & args){
double sum_p = 0.0;
for(int32_t z = 0; z < topic_num; z++){
int32_t n_dz = doc_topic_counts[z];
int32_t n_wz = model[w * (topic_num + 2) + z];
int32_t n_wz = model[w * (topic_num + 1) + z];
int64_t n_z = total_topic_counts[z];
sum_p += (static_cast<double>(n_wz) + beta) * (n_dz + alpha)
/ (static_cast<double>(n_z) + voc_size * beta);
Expand Down Expand Up @@ -698,7 +709,7 @@ lda_check_count_ceiling::run(AnyType &args) {
int count = 0;
const int32_t *model = reinterpret_cast<const int32_t *>(model64.ptr());
for (int wordid = 0; wordid < voc_size; wordid ++) {
int flag = model[wordid * (topic_num + 2) + topic_num];
int flag = model[wordid * (topic_num + 1) + topic_num];
if (flag != 0) {
example_words_hit_ceiling[count ++] = wordid;
}
Expand Down
2 changes: 1 addition & 1 deletion src/ports/postgres/modules/lda/lda.py_in
Original file line number Diff line number Diff line change
Expand Up @@ -1129,5 +1129,5 @@ def _validate_model_table(model_table):
'alpha in %s should be a positive real number' % (model_table))
_assert(rv[0]['beta'] > 0,
'beta in %s should be a positive real number' % (model_table))
_assert(rv[0]['model_size'] * 2 == (rv[0]['voc_size']) * (rv[0]['topic_num'] + 2),
_assert(rv[0]['model_size'] == ((rv[0]['voc_size']) * (rv[0]['topic_num'] + 1) + 1) / 2,
"model_size mismatches with voc_size and topic_num in %s" % (model_table))
54 changes: 46 additions & 8 deletions src/ports/postgres/modules/lda/test/lda.sql_in
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ m4_include(`SQLCommon.m4')
---------------------------------------------------------------------------
-- Build vocabulary:
---------------------------------------------------------------------------
CREATE TABLE lda_vocab(wordid INT4, word TEXT)
m4_ifdef(`__GREENPLUM__',`DISTRIBUTED BY (wordid)');
CREATE TABLE lda_vocab(wordid INT4, word TEXT);

INSERT INTO lda_vocab VALUES
(0, 'code'), (1, 'data'), (2, 'graph'), (3, 'image'), (4, 'input'), (5,
Expand All @@ -28,8 +27,7 @@ CREATE TABLE lda_training
docid INT4,
wordid INT4,
count INT4
)
m4_ifdef(`__GREENPLUM__',`DISTRIBUTED BY (docid)');
);

INSERT INTO lda_training VALUES
(0, 0, 2),(0, 3, 2),(0, 5, 1),(0, 7, 1),(0, 8, 1),(0, 9, 1),(0, 11, 1),(0, 13,
Expand All @@ -45,6 +43,27 @@ INSERT INTO lda_training VALUES
(9, 0, 1),(9, 1, 1),(9, 4, 1),(9, 9, 2),(9, 12, 2),(9, 15, 1),(9, 18, 1),(9,
19, 1);

CREATE TABLE lda_training_odd_voc_size
(
docid INT4,
wordid INT4,
count INT4
);

INSERT INTO lda_training_odd_voc_size VALUES
(0, 0, 2),(0, 3, 2),(0, 5, 1),(0, 7, 1),(0, 8, 1),(0, 9, 1),(0, 11, 1),(0, 13,
1), (1, 0, 1),(1, 3, 1),(1, 4, 1),(1, 5, 1),(1, 6, 1),(1, 7, 1),(1, 10, 1),(1,
14, 1),(1, 17, 1),(1, 18, 1), (2, 4, 2),(2, 5, 1),(2, 6, 2),(2, 12, 1),(2, 13,
1),(2, 15, 1),(2, 18, 2), (3, 0, 1),(3, 1, 2),(3, 12, 3),(3, 16, 1),(3, 17,
2),(3, 19, 1), (4, 1, 1),(4, 2, 1),(4, 3, 1),(4, 5, 1),(4, 6, 1),(4, 10, 1),(4,
11, 1),(4, 14, 1),(4, 18, 1),(4, 19, 1), (5, 0, 1),(5, 2, 1),(5, 5, 1),(5, 7,
1),(5, 10, 1),(5, 12, 1),(5, 16, 1),(5, 18, 1),(5, 19, 2), (6, 1, 1),(6, 3,
1),(6, 12, 2),(6, 13, 1),(6, 14, 2),(6, 15, 1),(6, 16, 1),(6, 17, 1), (7, 0,
1),(7, 2, 1),(7, 4, 1),(7, 5, 1),(7, 7, 2),(7, 8, 1),(7, 11, 1),(7, 14, 1),(7,
16, 1), (8, 2, 1),(8, 4, 4),(8, 6, 2),(8, 11, 1),(8, 15, 1),(8, 18, 1),
(9, 0, 1),(9, 1, 1),(9, 4, 1),(9, 9, 2),(9, 12, 2),(9, 15, 1),(9, 18, 1),(9,
19, 1),(9, 20, 1);


CREATE TABLE lda_testing
(
Expand Down Expand Up @@ -104,20 +123,20 @@ SELECT __lda_util_norm_with_smoothing(array[1, 4, 2, 3], 0.1);
SELECT
assert(
__lda_check_count_ceiling(
array[1, 0, 4, 0]::bigint[],
array[0, 0, 0, 0]::bigint[],
2,
2)
IS NULL,
'__lda_check_count_ceiling should return NULL for [1, 0, 4, 0]');
'__lda_check_count_ceiling should return NULL for [0, 0, 0, 0]');
-- length: 1
SELECT
assert(
__lda_check_count_ceiling(
array[1, -1, 4, 0]::bigint[],
array[-1, -1, -1, -1]::bigint[],
2,
2)
IS NOT NULL,
'__lda_check_count_ceiling should not return NULL for [1, -1, 4, 0]');
'__lda_check_count_ceiling should not return NULL for [-1, -1, -1, -1]');

SELECT lda_get_topic_desc(
'lda_model',
Expand All @@ -143,3 +162,22 @@ FROM __lda_util_norm_dataset('lda_testing', 'norm_lda_vocab',
SELECT *
FROM __lda_util_conorm_data('lda_testing', 'lda_vocab',
'norm_lda_data_2', 'norm_lda_vocab_2');

-- both voc_size and topic_num are odd or even
SELECT lda_train(
'lda_training_odd_voc_size',
'lda_model_odd_voc_size_even_topic_num',
'lda_output_odd_voc_size_even_topic_num',
21, 6, 2, 3, 0.01);

SELECT lda_train(
'lda_training_odd_voc_size',
'lda_model_odd_voc_size_odd_topic_num',
'lda_output_odd_voc_size_odd_topic_num',
21, 5, 2, 3, 0.01);

SELECT lda_train(
'lda_training',
'lda_model_even_voc_size_even_topic_num',
'lda_output_even_voc_size_even_topic_num',
20, 6, 2, 3, 0.01);

0 comments on commit bf083f5

Please sign in to comment.