# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
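"""Build time-based train/validation/test splits from raw Netflix rating files.

Usage (inferred from main() below; the script name here is only a placeholder):
    python <this_script>.py <folder with per-movie *.txt rating files> <output folder>

Each split is written by save_data_to_file() as tab-separated
"userId<TAB>itemId<TAB>rating" lines under the NF_*, N3M_*, N6M_* and N1Y_*
sub-folders of the output folder.
"""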
from os import listdir, path, makedirs
import random
import sys
import time
import datetime


def print_stats(data):
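    """Print the total number of ratings and distinct users in a userId -> ratings dict."""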
    total_ratings = 0
    print("STATS")
    for user in data:
        total_ratings += len(data[user])
    print("Total Ratings: {}".format(total_ratings))
    print("Total User count: {}".format(len(data.keys())))


def save_data_to_file(data, filename):
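    """Write every rating as a "userId<TAB>itemId<TAB>rating" line to filename."""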
    with open(filename, 'w') as out:
        for userId in data:
            for record in data[userId]:
                out.write("{}\t{}\t{}\n".format(userId, record[0], record[1]))


def create_NETFLIX_data_timesplit(all_data, train_min, train_max, test_min,
                                  test_max):
| 40 | + """ |
| 41 | + Creates time-based split of NETFLIX data into train, and (validation, test) |
| 42 | + :param all_data: |
| 43 | + :param train_min: |
| 44 | + :param train_max: |
| 45 | + :param test_min: |
| 46 | + :param test_max: |
| 47 | + :return: |
| 48 | + """ |
    train_min_ts = time.mktime(
        datetime.datetime.strptime(train_min, "%Y-%m-%d").timetuple())
    train_max_ts = time.mktime(
        datetime.datetime.strptime(train_max, "%Y-%m-%d").timetuple())
    test_min_ts = time.mktime(
        datetime.datetime.strptime(test_min, "%Y-%m-%d").timetuple())
    test_max_ts = time.mktime(
        datetime.datetime.strptime(test_max, "%Y-%m-%d").timetuple())

    training_data = dict()
    validation_data = dict()
    test_data = dict()

    train_set_items = set()

    for userId, userRatings in all_data.items():
        time_sorted_ratings = sorted(
            userRatings, key=lambda x: x[2])  # sort by timestamp
        for rating_item in time_sorted_ratings:
            if train_min_ts <= rating_item[2] <= train_max_ts:
                if userId not in training_data:
                    training_data[userId] = []
                training_data[userId].append(rating_item)
                train_set_items.add(
                    rating_item[0])  # keep track of items from training set
            elif test_min_ts <= rating_item[2] <= test_max_ts:
                if userId not in training_data:
                    # only include users seen in the training set
                    continue
                p = random.random()
                if p <= 0.5:
                    if userId not in validation_data:
                        validation_data[userId] = []
                    validation_data[userId].append(rating_item)
                else:
                    if userId not in test_data:
                        test_data[userId] = []
                    test_data[userId].append(rating_item)

    # remove items not seen in the training set
    for userId, userRatings in test_data.items():
        test_data[userId] = [
            rating for rating in userRatings if rating[0] in train_set_items
        ]
    for userId, userRatings in validation_data.items():
        validation_data[userId] = [
            rating for rating in userRatings if rating[0] in train_set_items
        ]

    return training_data, validation_data, test_data


def main(args):
    user2id_map = dict()
    item2id_map = dict()
    userId = 0
    itemId = 0
    all_data = dict()

    folder = args[1]
    out_folder = args[2]
    # create the output folders for every split that is saved below:
    split_dirs = [
        "/NF_TRAIN", "/NF_VALID", "/NF_TEST", "/N3M_TRAIN", "/N3M_VALID",
        "/N3M_TEST", "/N6M_TRAIN", "/N6M_VALID", "/N6M_TEST", "/N1Y_TRAIN",
        "/N1Y_VALID", "/N1Y_TEST"
    ]
    for output_dir in [(out_folder + f) for f in split_dirs]:
        makedirs(output_dir, exist_ok=True)
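
    # Each *.txt file collected below is assumed to follow the Netflix Prize
    # per-movie layout that the parsing loop relies on: the first line is
    # "<movieId>:" and every following line is "<userId>,<rating>,<YYYY-MM-DD>".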
    text_files = [
        path.join(folder, f) for f in listdir(folder)
        if path.isfile(path.join(folder, f)) and ('.txt' in f)
    ]

    for text_file in text_files:
        with open(text_file, 'r') as f:
            print("Processing: {}".format(text_file))
            lines = f.readlines()
            item = int(lines[0][:-2])  # remove newline and :
            if item not in item2id_map:
                item2id_map[item] = itemId
                itemId += 1

            for line in lines[1:]:
                parts = line.strip().split(",")
                user = int(parts[0])
                if user not in user2id_map:
                    user2id_map[user] = userId
                    userId += 1
                rating = float(parts[1])
                ts = int(
                    time.mktime(
                        datetime.datetime.strptime(parts[2], "%Y-%m-%d")
                        .timetuple()))
                if user2id_map[user] not in all_data:
                    all_data[user2id_map[user]] = []
                all_data[user2id_map[user]].append(
                    (item2id_map[item], rating, ts))

    print("STATS FOR ALL INPUT DATA")
    print_stats(all_data)

    # Netflix full
    (nf_train, nf_valid, nf_test) = create_NETFLIX_data_timesplit(
        all_data, "1999-12-01", "2005-11-30", "2005-12-01", "2005-12-31")
    print("Netflix full train")
    print_stats(nf_train)
    save_data_to_file(nf_train, out_folder + "/NF_TRAIN/nf.train.txt")
    print("Netflix full valid")
    print_stats(nf_valid)
    save_data_to_file(nf_valid, out_folder + "/NF_VALID/nf.valid.txt")
    print("Netflix full test")
    print_stats(nf_test)
    save_data_to_file(nf_test, out_folder + "/NF_TEST/nf.test.txt")

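    # Netflix 3 months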
    (n3m_train, n3m_valid, n3m_test) = create_NETFLIX_data_timesplit(
        all_data, "2005-09-01", "2005-11-30", "2005-12-01", "2005-12-31")
    print("Netflix 3m train")
    print_stats(n3m_train)
    save_data_to_file(n3m_train, out_folder + "/N3M_TRAIN/n3m.train.txt")
    print("Netflix 3m valid")
    print_stats(n3m_valid)
    save_data_to_file(n3m_valid, out_folder + "/N3M_VALID/n3m.valid.txt")
    print("Netflix 3m test")
    print_stats(n3m_test)
    save_data_to_file(n3m_test, out_folder + "/N3M_TEST/n3m.test.txt")

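    # Netflix 6 months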
    (n6m_train, n6m_valid, n6m_test) = create_NETFLIX_data_timesplit(
        all_data, "2005-06-01", "2005-11-30", "2005-12-01", "2005-12-31")
    print("Netflix 6m train")
    print_stats(n6m_train)
    save_data_to_file(n6m_train, out_folder + "/N6M_TRAIN/n6m.train.txt")
    print("Netflix 6m valid")
    print_stats(n6m_valid)
    save_data_to_file(n6m_valid, out_folder + "/N6M_VALID/n6m.valid.txt")
    print("Netflix 6m test")
    print_stats(n6m_test)
    save_data_to_file(n6m_test, out_folder + "/N6M_TEST/n6m.test.txt")

    # Netflix 1 year
    (n1y_train, n1y_valid, n1y_test) = create_NETFLIX_data_timesplit(
        all_data, "2004-06-01", "2005-05-31", "2005-06-01", "2005-06-30")
    print("Netflix 1y train")
    print_stats(n1y_train)
    save_data_to_file(n1y_train, out_folder + "/N1Y_TRAIN/n1y.train.txt")
    print("Netflix 1y valid")
    print_stats(n1y_valid)
    save_data_to_file(n1y_valid, out_folder + "/N1Y_VALID/n1y.valid.txt")
    print("Netflix 1y test")
    print_stats(n1y_test)
    save_data_to_file(n1y_test, out_folder + "/N1Y_TEST/n1y.test.txt")


if __name__ == "__main__":
    main(sys.argv)