Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions examples/NeuMF/regression/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import argparse
import inspect

import torch


class Config:
device = torch.device("cuda:0")
# device = torch.device("cpu")
train_epochs = 10
batch_size = 128
learning_rate = 0.01
l2_regularization = 1e-3 # 正则化系数
learning_rate_decay = 0.99 # 学习率衰减程度

dataset_file = 'rllm/rllm/datasets/rel-movielens1m/regression/ratings/'

mf_dim = 10
mlp_layers = [32, 16, 8]

def __init__(self):
attributes = inspect.getmembers(self, lambda a: not inspect.isfunction(a))
attributes = list(filter(lambda x: not x[0].startswith('__'), attributes))

parser = argparse.ArgumentParser()
for key, val in attributes:
parser.add_argument('--' + key, dest=key, type=type(val), default=val)
for key, val in parser.parse_args().__dict__.items():
self.__setattr__(key, val)

def __str__(self):
attributes = inspect.getmembers(self, lambda a: not inspect.isfunction(a))
attributes = list(filter(lambda x: not x[0].startswith('__'), attributes))
to_str = ''
for key, val in attributes:
to_str += '{} = {}\n'.format(key, val)
return to_str
94 changes: 94 additions & 0 deletions examples/NeuMF/regression/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch
from torch import nn


class GMF(nn.Module):
def __init__(self, num_user, num_item, mf_dim=10, trainable=True):
super().__init__()
self.trainable = trainable
self.mf_user_emb = nn.Embedding(num_embeddings=num_user, embedding_dim=mf_dim)
self.mf_item_emb = nn.Embedding(num_embeddings=num_item, embedding_dim=mf_dim)
if trainable: # 预训练
self.linear = nn.Sequential(
nn.Linear(mf_dim, 1),
#nn.Sigmoid()
)
else:
trained = torch.load('weights/GMF.pt').state_dict()
for name, val in self.named_parameters():
val.data = trained[name]
val.requires_grad = False

def forward(self, user_id, item_id):
mf_vec = self.mf_user_emb(user_id) * self.mf_item_emb(item_id)
if self.trainable:
pred = self.linear(mf_vec)
return pred.squeeze()
else:
return mf_vec


class MLP(nn.Module):
def __init__(self, num_user, num_item, mlp_layers=None, trainable=True):
super().__init__()
if mlp_layers is None:
mlp_layers = [10]
self.trainable = trainable
self.mlp_user_emb = nn.Embedding(num_embeddings=num_user, embedding_dim=mlp_layers[0] // 2)
self.mlp_item_emb = nn.Embedding(num_embeddings=num_item, embedding_dim=mlp_layers[0] // 2)
#print(self.mlp_user_emb)

self.mlp = nn.ModuleList()
for i in range(1, len(mlp_layers)):
self.mlp.append(nn.Linear(mlp_layers[i - 1], mlp_layers[i]))
self.mlp.append(nn.ReLU())
if trainable:
self.linear = nn.Sequential(
nn.Linear(mlp_layers[-1], 1),
#nn.Sigmoid()
)
else:
trained = torch.load('weights/MLP.pt').state_dict()
for name, val in self.named_parameters():
val.data = trained[name]
val.requires_grad = False

def forward(self, user_id, item_id):
# print(self.mlp_item_emb.num_embeddings)

# print(item_id.min())
# print(item_id.max())

# print(self.mlp_user_emb(user_id).size())
# print(self.mlp_item_emb(item_id).size())

mlp_vec = torch.cat([self.mlp_user_emb(user_id), self.mlp_item_emb(item_id)], dim=-1)
for layer in self.mlp:
mlp_vec = layer(mlp_vec)
if self.trainable:
prediction = self.linear(mlp_vec)
return prediction.squeeze()
else:
return mlp_vec


class NeuMF(nn.Module):

def __init__(self, num_user, num_item, mf_dim=10, mlp_layers=None, use_pretrain=True):
super().__init__()
if mlp_layers is None:
mlp_layers = [10]
self.gmf = GMF(num_user, num_item, mf_dim, trainable=not use_pretrain) # 默认直接使用预训练好的权重
self.mlp = MLP(num_user, num_item, mlp_layers=mlp_layers, trainable=not use_pretrain)
self.linear = nn.Sequential(
nn.Linear(mf_dim + mlp_layers[-1], 1),
#nn.Sigmoid()
)

def forward(self, user_id, item_id):
gmf_vec = self.gmf(user_id, item_id)
mlp_vec = self.mlp(user_id, item_id)
# NueMF
cat = torch.cat([gmf_vec, mlp_vec], dim=-1)
prediction = self.linear(cat)
return prediction.squeeze()
112 changes: 112 additions & 0 deletions examples/NeuMF/regression/regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# NeuMF for rating prediction in regression task
# Paper: Neural Collaborative Filtering (NIPS 2017)
# Test MSE Loss: 0.960542
# Runtime: 466s on RTX3060
# Cost: N/A
# Description: remove the sigmoid layer and replace the BCEloss with MSEloss in training
import os
import time
import pandas as pd

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from config import Config
from model import NeuMF, GMF, MLP
from utils import date, predict_mse, NCFDataset


def train(train_dataloader, valid_dataloader, model, config, model_path):
print(f'{date()}## Start the training!')
train_mse = predict_mse(model, train_dataloader, config.device)
valid_mse = predict_mse(model, valid_dataloader, config.device)
print(f'{date()}#### Initial train mse {train_mse:.6f}, validation mse {valid_mse:.6f}')
start_time = time.perf_counter()

opt = torch.optim.Adam(model.parameters(), config.learning_rate, weight_decay=config.l2_regularization)
lr_sch = torch.optim.lr_scheduler.ExponentialLR(opt, config.learning_rate_decay)

best_loss = 100
for epoch in range(config.train_epochs):
model.train() # 将模型设置为训练状态
total_loss, total_samples = 0, 0
for batch in train_dataloader:
user_id, item_id, ratings = [i.to(config.device) for i in batch]
predict = model(user_id, item_id)
loss = F.mse_loss(predict, ratings, reduction='mean')
opt.zero_grad()
loss.backward()
opt.step()

total_loss += loss.item() * len(predict)
total_samples += len(predict)

lr_sch.step()
model.eval() # 停止训练状态
valid_mse = predict_mse(model, valid_dataloader, config.device)
train_loss = total_loss / total_samples
print(f"{date()}#### Epoch {epoch:3d}; train mse {train_loss:.6f}; validation mse {valid_mse:.6f}")

if best_loss > valid_mse:
best_loss = valid_mse
torch.save(model, model_path)

end_time = time.perf_counter()
print(f'{date()}## End of training! Time used {end_time - start_time:.0f} seconds.')


def test(dataloader, model):
print(f'{date()}## Start the testing!')
start_time = time.perf_counter()
test_loss = predict_mse(model, dataloader, next(model.parameters()).device)
end_time = time.perf_counter()
print(f"{date()}## Test end, test mse is {test_loss:.6f}, time used {end_time - start_time:.0f} seconds.")


def main():
#加载参数
config = Config()
print(config)

#加载dataset
train_data = pd.read_csv(config.dataset_file + 'train.csv', usecols=[0, 1, 2])
train_data.columns = ['userID', 'itemID', 'rating']
valid_data = pd.read_csv(config.dataset_file + 'validation.csv', usecols=[0, 1, 2])
valid_data.columns = ['userID', 'itemID', 'rating']
test_data = pd.read_csv(config.dataset_file + 'test.csv', usecols=[0, 1, 2])
test_data.columns = ['userID', 'itemID', 'rating']

user_count = max(train_data['userID']) + 1
item_count = max(train_data['itemID']) + 1
print(f"{date()}## Dataset contains {train_data.shape[0]} records, {user_count} users and {item_count} items.")

train_dataset = NCFDataset(train_data)
valid_dataset = NCFDataset(valid_data)
test_dataset = NCFDataset(test_data)

train_dlr = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
valid_dlr = DataLoader(valid_dataset, batch_size=config.batch_size)
test_dlr = DataLoader(test_dataset, batch_size=config.batch_size)

os.makedirs('./weights', exist_ok=True) # 文件夹不存在则创建
print(f'{date()}############ 预训练MLP ###########################')
model_MLP = MLP(user_count, item_count, config.mlp_layers).to(config.device)
train(train_dlr, valid_dlr, model_MLP, config, 'weights/MLP.pt')
test(test_dlr, torch.load('weights/MLP.pt'))


print(f'{date()}############ 预训练GMF ###########################')
model_GMF = GMF(user_count, item_count, config.mf_dim).to(config.device)
train(train_dlr, valid_dlr, model_GMF, config, 'weights/GMF.pt')
test(test_dlr, torch.load('weights/GMF.pt'))


print(f'{date()}############ 训练NeuMF ###########################')
model_NeuMF = NeuMF(user_count, item_count, config.mf_dim, config.mlp_layers, use_pretrain=True).to(config.device)
train(train_dlr, valid_dlr, model_NeuMF, config, 'weights/NeuMF.pt')
test(test_dlr, torch.load('weights/NeuMF.pt'))


if __name__ == '__main__':
main()
41 changes: 41 additions & 0 deletions examples/NeuMF/regression/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import time
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from config import Config


class NCFDataset(Dataset):
def __init__(self, df):
self.user_id = torch.LongTensor(df['userID'].to_list())
print(len(df['itemID']))
self.item_id = torch.LongTensor(df['itemID'].to_list())
self.rating = torch.Tensor(df['rating'].to_list())

def __getitem__(self, idx):
return self.user_id[idx], self.item_id[idx], self.rating[idx]

def __len__(self):
return self.rating.shape[0]


def date(f='%Y-%m-%d %H:%M:%S'):
return time.strftime(f, time.localtime())


def predict_mse(trained_model, dataloader, device):
mse, sample_count = 0, 0
with torch.no_grad():
for batch in dataloader:
user_id, item_id, ratings = [i.to(device) for i in batch]
#print(user_id, item_id)
#print(user_id.size())
predict = trained_model(user_id, item_id)
mse += torch.nn.functional.mse_loss(predict, ratings, reduction='sum').item()
sample_count += len(ratings)
return mse / sample_count # dataloader上的均方误差