diff --git a/SarcasmDetector/README.md b/SarcasmDetector/README.md new file mode 100644 index 00000000..54a6889c --- /dev/null +++ b/SarcasmDetector/README.md @@ -0,0 +1,14 @@ +# Sarcasm Detector + +tldr? A text classifier that detects sarcasm. + +As is well known to fams and fans of The Big Bang Theory, Sheldon Cooper, one of the most brilliant minds of 21st century American Entertainment, sadly struggles to detect sarcasm in casual conversation. So a bunch of Cool Women in Tech decided to partner up and solve this pressing problem in modern pop culture. + +In this repo we document this journey for the reference of techies undertaking a similar quest in the near or far future! We hope you enjoy the process as much as we did, and hope to make it easier for you to follow along! ^_^ + + +>contributors + +>technical details + +> diff --git a/SaveSheldon/sarcasm.ipynb b/SaveSheldon/sarcasm.ipynb new file mode 100644 index 00000000..86c5bfac --- /dev/null +++ b/SaveSheldon/sarcasm.ipynb @@ -0,0 +1,1772 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "sarcasm.ipynb", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "metadata": { + "id": "BHq8guDSTy8q", + "colab_type": "code", + "outputId": "66d42816-4aba-487a-933a-2d8207b886ee", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 122 + } + }, + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Go to this URL in a browser: 
https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code\n", + "\n", + "Enter your authorization code:\n", + "··········\n", + "Mounted at /content/drive\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "SimktZGaTz5H", + "colab_type": "code", + "outputId": "889df6a8-5791-4a1b-c15f-6a1228997239", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "!unzip '/content/drive/My Drive/train-balanced-sarcasm.csv.zip'" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Archive: /content/drive/My Drive/train-balanced-sarcasm.csv.zip\n", + " inflating: train-balanced-sarcasm.csv \n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "1JuN0ZYyW1SL", + "colab_type": "code", + "outputId": "c00ea784-5380-484c-e730-24e78d3498e2", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + } + }, + "source": [ + "import numpy as np # linear algebra\n", + "import pandas as pd # data processing, CSV file I/O (e.g. 
pd.read_csv)\n", + "from plotly import tools\n", + "import plotly.offline as py\n", + "py.init_notebook_mode(connected=True)\n", + "import plotly.graph_objs as go\n", + "from collections import defaultdict\n", + "from matplotlib import pyplot as plt\n", + "%matplotlib inline\n", + "import re\n", + "import torch\n", + "from torch.utils.data import DataLoader, TensorDataset" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/vnd.plotly.v1+html": "", + "text/html": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2lQDgLiVTz60", + "colab_type": "code", + "colab": {} + }, + "source": [ + "train_data=pd.read_csv('/content/train-balanced-sarcasm.csv')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Y1WgCA09Tz93", + "colab_type": "code", + "outputId": "49d327d5-e483-44c3-9ae6-c829ef389bf1", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 374 + } + }, + "source": [ + "train_data.head()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelcommentauthorsubredditscoreupsdownsdatecreated_utcparent_comment
00NC and NH.Trumpbartpolitics2-1-12016-102016-10-16 23:55:23Yeah, I get that argument. At this point, I'd ...
10You do know west teams play against west teams...Shbshb906nba-4-1-12016-112016-11-01 00:24:10The blazers and Mavericks (The wests 5 and 6 s...
20They were underdogs earlier today, but since G...Creepethnfl3302016-092016-09-22 21:45:37They're favored to win.
30This meme isn't funny none of the \"new york ni...icebrothaBlackPeopleTwitter-8-1-12016-102016-10-18 21:03:47deadass don't kill my buzz
40I could use one of those tools.cush2pushMaddenUltimateTeam6-1-12016-122016-12-30 17:00:13Yep can confirm I saw the tool they use for th...
\n", + "
" + ], + "text/plain": [ + " label ... parent_comment\n", + "0 0 ... Yeah, I get that argument. At this point, I'd ...\n", + "1 0 ... The blazers and Mavericks (The wests 5 and 6 s...\n", + "2 0 ... They're favored to win.\n", + "3 0 ... deadass don't kill my buzz\n", + "4 0 ... Yep can confirm I saw the tool they use for th...\n", + "\n", + "[5 rows x 10 columns]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 18 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ogouztYRUChv", + "colab_type": "code", + "outputId": "fc2e8096-111d-41b9-8556-8b7ba587d6c1", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 272 + } + }, + "source": [ + "train_data.info()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n", + "RangeIndex: 1010826 entries, 0 to 1010825\n", + "Data columns (total 10 columns):\n", + "label 1010826 non-null int64\n", + "comment 1010773 non-null object\n", + "author 1010826 non-null object\n", + "subreddit 1010826 non-null object\n", + "score 1010826 non-null int64\n", + "ups 1010826 non-null int64\n", + "downs 1010826 non-null int64\n", + "date 1010826 non-null object\n", + "created_utc 1010826 non-null object\n", + "parent_comment 1010826 non-null object\n", + "dtypes: int64(4), object(6)\n", + "memory usage: 77.1+ MB\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Yil_0nYRTz_v", + "colab_type": "code", + "colab": {} + }, + "source": [ + "train_data.dropna(subset=['comment'], inplace=True)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "7lbG1vVsT0C8", + "colab_type": "code", + "outputId": "1ea56472-0789-4d02-f1e3-beba074c9d63", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 272 + } + }, + "source": [ + "train_data.info()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n", + "Int64Index: 1010773 entries, 0 to 
1010825\n", + "Data columns (total 10 columns):\n", + "label 1010773 non-null int64\n", + "comment 1010773 non-null object\n", + "author 1010773 non-null object\n", + "subreddit 1010773 non-null object\n", + "score 1010773 non-null int64\n", + "ups 1010773 non-null int64\n", + "downs 1010773 non-null int64\n", + "date 1010773 non-null object\n", + "created_utc 1010773 non-null object\n", + "parent_comment 1010773 non-null object\n", + "dtypes: int64(4), object(6)\n", + "memory usage: 84.8+ MB\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "a2l6PiF4T0He", + "colab_type": "code", + "outputId": "f4141444-733f-4964-87ac-a4299e6f440c", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "train_data['label'].unique()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array([0, 1])" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 21 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "1RmuDIUxT0MM", + "colab_type": "code", + "outputId": "237c9949-4516-4305-f291-735e0ca1de32", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + } + }, + "source": [ + "train_data.loc[train_data['label'] == 1, 'comment'].str.len()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "33 35\n", + "44 145\n", + "45 107\n", + "66 67\n", + "69 65\n", + "74 40\n", + "76 116\n", + "160 30\n", + "169 56\n", + "193 182\n", + "194 58\n", + "216 188\n", + "230 56\n", + "255 100\n", + "274 22\n", + "278 60\n", + "305 33\n", + "336 42\n", + "346 34\n", + "359 28\n", + "370 39\n", + "443 96\n", + "457 28\n", + "479 139\n", + "487 43\n", + "497 35\n", + "508 34\n", + "526 26\n", + "556 48\n", + "575 5\n", + " ... 
\n", + "1010794 69\n", + "1010795 49\n", + "1010796 5\n", + "1010797 65\n", + "1010798 20\n", + "1010799 9\n", + "1010800 64\n", + "1010801 143\n", + "1010802 184\n", + "1010803 21\n", + "1010805 38\n", + "1010806 20\n", + "1010807 219\n", + "1010808 53\n", + "1010809 338\n", + "1010810 23\n", + "1010811 17\n", + "1010812 65\n", + "1010813 25\n", + "1010814 131\n", + "1010816 13\n", + "1010817 77\n", + "1010818 34\n", + "1010819 38\n", + "1010820 2\n", + "1010821 92\n", + "1010822 34\n", + "1010823 66\n", + "1010824 53\n", + "1010825 72\n", + "Name: comment, Length: 505368, dtype: int64" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 22 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4h-5SFSaT0OM", + "colab_type": "code", + "outputId": "0b8c922d-8b27-4bd6-aea5-6302a60b4fa1", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 785 + } + }, + "source": [ + "cnt_label = train_data['label'].value_counts()\n", + "print(cnt_label)\n", + "labels = (np.array(cnt_label.index))\n", + "sizes=(np.array((cnt_label/cnt_label.sum())*100))\n", + "print(sizes)\n", + "trace = go.Pie(labels=labels, values=sizes)\n", + "layout = go.Layout(title='Label Distribution', font=dict(size=15),\n", + " width=700, height=700)\n", + "\n", + "data = [trace]\n", + "fig = go.Figure(data=data, layout=layout)\n", + "py.iplot(fig, filename=\"LabelDistribution\")" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "0 505405\n", + "1 505368\n", + "Name: label, dtype: int64\n", + "[50.00183028 49.99816972]\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "text/vnd.plotly.v1+html": "
", + "text/html": [ + "
" + ], + "application/vnd.plotly.v1+json": { + "layout": { + "title": { + "text": "Label Distribution" + }, + "font": { + "size": 15 + }, + "width": 700, + "height": 700 + }, + "config": { + "plotlyServerURL": "https://plot.ly", + "linkText": "Export to plot.ly", + "showLink": false + }, + "data": [ + { + "type": "pie", + "labels": [ + 0, + 1 + ], + "values": [ + 50.001830282368054, + 49.99816971763195 + ], + "uid": "7e84568b-b29c-4213-a806-387dca1a08ed" + } + ] + } + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "mbtWNnH3bype", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def horizontal_bar_chart(df, color):\n", + " trace = go.Bar(y = df[\"word\"].values[::-1],\n", + " x = df[\"wordcount\"].values[::-1],\n", + " showlegend=False,\n", + " orientation='h',\n", + " marker=dict(color=color))\n", + " return trace" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "_1TAGy0wT0Ra", + "colab_type": "code", + "colab": {} + }, + "source": [ + "train1_df = train_data[train_data['label']==1]\n", + "train0_df = train_data[train_data['label']==0]\n", + "from wordcloud import WordCloud, STOPWORDS\n", + "def generate_ngrams(text, n_gram=1):\n", + " token = [token for token in text.lower().split(\" \") if token!= \"\" if token not\n", + " in STOPWORDS]\n", + " \n", + " ngrams = zip(*[token[i:] for i in range(n_gram)])\n", + " return [\" \".join(ngram) for ngram in ngrams]\n", + "\n", + "freq_dict = defaultdict(int)\n", + "for sent in train0_df[\"comment\"]:\n", + " for word in generate_ngrams(sent):\n", + " freq_dict[word] += 1\n", + "fd_sorted = pd.DataFrame(sorted(freq_dict.items(), key=lambda x: x[1])[::-1])\n", + "fd_sorted.columns = [\"word\", \"wordcount\"]\n", + "trace1 = horizontal_bar_chart(fd_sorted.head(50), 'blue')\n", + " " + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "_agqrum1T0WM", + 
"colab_type": "code", + "outputId": "2b002a72-3ff2-42f4-adbf-120abbc8967a", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 390 + } + }, + "source": [ + "sub_df = train_data.groupby('subreddit')['label'].agg([np.size, np.mean, np.sum])\n", + "sub_df.sort_values(by='sum', ascending=False).head(10)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sizemeansum
subreddit
AskReddit656740.40145326365
politics394930.60534823907
worldnews263760.64251616947
leagueoflegends210340.54231211407
pcmasterrace189870.56665110759
news168910.60345710193
funny179390.4514748099
pics161520.4843367823
todayilearned141590.5475677753
GlobalOffensive137380.5520457584
\n", + "
" + ], + "text/plain": [ + " size mean sum\n", + "subreddit \n", + "AskReddit 65674 0.401453 26365\n", + "politics 39493 0.605348 23907\n", + "worldnews 26376 0.642516 16947\n", + "leagueoflegends 21034 0.542312 11407\n", + "pcmasterrace 18987 0.566651 10759\n", + "news 16891 0.603457 10193\n", + "funny 17939 0.451474 8099\n", + "pics 16152 0.484336 7823\n", + "todayilearned 14159 0.547567 7753\n", + "GlobalOffensive 13738 0.552045 7584" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 26 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "e4DWndGiT0Yv", + "colab_type": "code", + "outputId": "00564040-bd8b-4a07-e125-0c21b10c5536", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 390 + } + }, + "source": [ + "sub_df[sub_df['size'] > 1000].sort_values(by='mean', ascending=False).head(10)\n" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sizemeansum
subreddit
creepyPMs54660.7843034287
MensRights33550.6807752284
ShitRedditSays12840.661994850
worldnews263760.64251616947
Libertarian25620.6401251640
atheism73770.6395554718
Conservative18810.6395531203
TwoXChromosomes15600.632692987
fatlogic23560.6230901468
facepalm12680.617508783
\n", + "
" + ], + "text/plain": [ + " size mean sum\n", + "subreddit \n", + "creepyPMs 5466 0.784303 4287\n", + "MensRights 3355 0.680775 2284\n", + "ShitRedditSays 1284 0.661994 850\n", + "worldnews 26376 0.642516 16947\n", + "Libertarian 2562 0.640125 1640\n", + "atheism 7377 0.639555 4718\n", + "Conservative 1881 0.639553 1203\n", + "TwoXChromosomes 1560 0.632692 987\n", + "fatlogic 2356 0.623090 1468\n", + "facepalm 1268 0.617508 783" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 27 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "np0jP520T0bL", + "colab_type": "code", + "colab": {} + }, + "source": [ + "train_data['comment']=train_data['comment'].str.lower()\n", + "train_data['parent_comment']=(train_data['parent_comment']).str.lower()\n", + "train_data['author']=(train_data['comment']).str.lower()\n", + "train_data['subreddit']=(train_data['subreddit']).str.lower()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "VFPRDEpjT0gQ", + "colab_type": "code", + "colab": {} + }, + "source": [ + "train_data['comment']=train_data['comment'].replace('[^\\w\\s]','', regex=True)\n", + "train_data['parent_comment']=train_data['parent_comment'].replace('[^\\w\\s]','', regex=True)\n", + "train_data['author']=train_data['author'].replace('[^\\w\\s]','', regex=True)\n", + "train_data['subreddit']=train_data['subreddit'].replace('[^\\w\\s]','', regex=True)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "H-SaNb4U_jrg", + "colab_type": "code", + "outputId": "51b60e43-6db3-459b-91f4-af992a1e1060", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "train_data['parent_comment'][4]" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'yep can confirm i saw the tool they use for that it was made by our boy easports_mut'" + ] + }, + "metadata": { + 
"tags": [] + }, + "execution_count": 18 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "83lTIy9sSP4f", + "colab_type": "code", + "outputId": "5fbff3cd-e2c0-493d-b22a-c68bad7a1dfe", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 261 + } + }, + "source": [ + "train_data.head(3)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelcommentauthorsubredditscoreupsdownsdatecreated_utcparent_comment
00nc and nhnc and nhpolitics2-1-12016-102016-10-16 23:55:23yeah i get that argument at this point id pref...
10you do know west teams play against west teams...you do know west teams play against west teams...nba-4-1-12016-112016-11-01 00:24:10the blazers and mavericks the wests 5 and 6 se...
20they were underdogs earlier today but since gr...they were underdogs earlier today but since gr...nfl3302016-092016-09-22 21:45:37theyre favored to win
\n", + "
" + ], + "text/plain": [ + " label ... parent_comment\n", + "0 0 ... yeah i get that argument at this point id pref...\n", + "1 0 ... the blazers and mavericks the wests 5 and 6 se...\n", + "2 0 ... theyre favored to win\n", + "\n", + "[3 rows x 10 columns]" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 19 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "tKQREGVsT0k7", + "colab_type": "code", + "colab": {} + }, + "source": [ + "train_data.drop(['author','subreddit','score','ups','downs','date','created_utc'], axis=1, inplace=True)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "PPwJAPMhbMOU", + "colab_type": "code", + "colab": {} + }, + "source": [ + "train_data.drop(['parent_comment'], axis=1, inplace=True)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "R3jSvxvG-5f5", + "colab_type": "code", + "outputId": "9c99724b-b742-4bbb-f706-349e9dc1e872", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "from string import punctuation\n", + "print(punctuation)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "IadaqnSA-_sI", + "colab_type": "code", + "colab": {} + }, + "source": [ + "all_reviews=list()\n", + "for text in train_data['comment']:\n", + " text = text.lower()\n", + " text = \"\".join([ch for ch in text if ch not in punctuation])\n", + " all_reviews.append(text)\n", + "all_text = \" \".join(all_reviews)\n", + "all_words = all_text.split()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "sqMdU79Z_N_R", + "colab_type": "code", + "outputId": "6937b862-1f5c-4284-a52e-e3470f77a251", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + 
"source": [ + "print(all_words[:10])" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "['nc', 'and', 'nh', 'you', 'do', 'know', 'west', 'teams', 'play', 'against']\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lGKHqJNK_6-g", + "colab_type": "code", + "outputId": "87ce60ac-e1b6-414e-b5a6-368dd7a399b3", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 34 + } + }, + "source": [ + "print(type(all_words))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "7ic1z5Nh_OMi", + "colab_type": "code", + "outputId": "d43e82d8-8b15-4b07-8eaf-6e0549ed4f07", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 54 + } + }, + "source": [ + "from collections import Counter \n", + "# Count all the words using Counter Method\n", + "count_words = Counter(all_words)\n", + "total_words=len(all_words)\n", + "(sorted_words)=count_words.most_common(total_words)\n", + "print(sorted_words[:10])\n", + "#print(\"Top ten occuring words : \"+sorted_words[:10])" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[('the', 396228), ('a', 244958), ('to', 239184), ('i', 184355), ('and', 173669), ('you', 171489), ('is', 154657), ('of', 149812), ('that', 140694), ('it', 128142)]\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0q4X41SsGzUn", + "colab_type": "code", + "colab": {} + }, + "source": [ + "vocab_to_int={w:i+1 for i,(w,c) in enumerate(sorted_words)}\n", + "#print(vocab_to_int)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "in-yWAc7GzSH", + "colab_type": "code", + "colab": {} + }, + "source": [ + "encoded_reviews=list()\n", + "for review in all_reviews:\n", + " encoded_review=list()\n", + " for word in 
review.split():\n", + " if word not in vocab_to_int.keys():\n", + " #if word is not available in vocab_to_int put 0 in that place\n", + " encoded_review.append(0)\n", + " else:\n", + " encoded_review.append(vocab_to_int[word])\n", + " encoded_reviews.append(encoded_review)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "qaK3zcKOIKJ-", + "colab_type": "code", + "outputId": "1d0be026-1b2b-431c-953d-59b192e8a536", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 422 + } + }, + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "reviews_len = [len(x) for x in encoded_reviews]\n", + "pd.Series(reviews_len).hist()\n", + "plt.show()\n", + "pd.Series(reviews_len).describe()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAD8CAYAAACyyUlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEzNJREFUeJzt3X+s3fV93/Hna7ikLm2CgekK2Wh2\nV2uTG7SVXBFPqaKrMhlDp5lJbUSEhpeh+I+QLl2YFmf9gypRpGQazQJKkbzixVQolKWZbG0mnkdy\nVO0PCJBSzI9R3xJSbBloMYE6UZM5fe+P8/F2rn3vNT4fm2Pf+3xIR/d73t/P9/v5nrfO5eXz/X7v\nIVWFJEk9/takD0CSdOEzTCRJ3QwTSVI3w0SS1M0wkSR1M0wkSd0ME0lSN8NEktTNMJEkdVsx6QN4\np1xxxRW1du3asbb9wQ9+wCWXXHJ2D+gCZ0/msh+nsidzXaj9ePLJJ/+yqv726cYtmzBZu3YtTzzx\nxFjbDgYDZmZmzu4BXeDsyVz241T2ZK4LtR9Jvvd2xnmaS5LUzTCRJHUzTCRJ3QwTSVI3w0SS1M0w\nkSR1M0wkSd1OGyZJdiZ5LckzI7XLkuxPcrD9XNXqSXJ3ktkkTye5ZmSbrW38wSRbR+rvS3KgbXN3\nkow7hyRpMt7OJ5OvAJtPqm0HHqmq9cAj7TnADcD69tgG3AvDYADuBN4PXAvceSIc2piPjmy3eZw5\nJEmTc9q/gK+qP0qy9qTyFmCmLe8CBsCnWv3+qirg0SSXJrmyjd1fVUcBkuwHNicZAO+uqkdb/X7g\nJuDhM52jqo6c2Ut/+w4cfpN/sf2/n6vdL+qlz//qROaVpDMx7jWTqZH/eL8CTLXl1cDLI+MOtdpi\n9UPz1MeZQ5I0Id3fzVVVlaTOxsGc7TmSbGN4KoypqSkGg8FY80+thDuuPj7Wtr3GPeZz7dixY+ft\nsU2C/TiVPZlrqfdj3DB59cSppXYa67VWPwxcNTJuTasd5v+fsjpRH7T6mnnGjzPHKapqB7ADYH
p6\nusb9krV7HtjNXQcm852YL90yM5F5T+dC/dK6c8V+nMqezLXU+zHuaa49wIk7srYCu0fqt7Y7rjYC\nb7ZTVfuATUlWtQvvm4B9bd1bSTa2u7huPWlfZzKHJGlCTvvP7SRfZfip4ookhxjelfV54KEktwHf\nAz7Uhu8FbgRmgR8CHwGoqqNJPgs83sZ95sTFeOBjDO8YW8nwwvvDrX5Gc0iSJuft3M314QVWXTfP\n2AJuX2A/O4Gd89SfAN47T/31M51DkjQZ/gW8JKmbYSJJ6maYSJK6GSaSpG6GiSSpm2EiSepmmEiS\nuhkmkqRuhokkqZthIknqZphIkroZJpKkboaJJKmbYSJJ6maYSJK6GSaSpG6GiSSpm2EiSepmmEiS\nuhkmkqRuhokkqZthIknqZphIkroZJpKkboaJJKmbYSJJ6maYSJK6GSaSpG6GiSSpm2EiSepmmEiS\nuhkmkqRuXWGS5F8neTbJM0m+muSnk6xL8liS2SR/kOTiNvZd7flsW792ZD+fbvUXklw/Ut/carNJ\nto/U551DkjQZY4dJktXAvwKmq+q9wEXAzcAXgC9W1S8AbwC3tU1uA95o9S+2cSTZ0Lb7RWAz8LtJ\nLkpyEfBl4AZgA/DhNpZF5pAkTUDvaa4VwMokK4CfAY4AvwJ8ra3fBdzUlre057T11yVJqz9YVT+q\nqu8Cs8C17TFbVS9W1Y+BB4EtbZuF5pAkTcDYYVJVh4H/APw5wxB5E3gS+H5VHW/DDgGr2/Jq4OW2\n7fE2/vLR+knbLFS/fJE5JEkTsGLcDZOsYvipYh3wfeC/MDxNdd5Isg3YBjA1NcVgMBhrP1Mr4Y6r\nj59+4Dkw7jGfa8eOHTtvj20S7Mep7MlcS70fY4cJ8I+B71bVXwAk+TrwAeDSJCvaJ4c1wOE2/jBw\nFXConRZ7D/D6SP2E0W3mq7++yBxzVNUOYAfA9PR0zczMjPVC73lgN3cd6GnV+F66ZWYi857OYDBg\n3H4uRfbjVPZkrqXej55rJn8ObEzyM+06xnXAc8C3gF9rY7YCu9vynvactv6bVVWtfnO722sdsB74\nNvA4sL7duXUxw4v0e9o2C80hSZqAnmsmjzG8CP4d4EDb1w7gU8Ank8wyvL5xX9vkPuDyVv8ksL3t\n51ngIYZB9A3g9qr6SfvU8XFgH/A88FAbyyJzSJImoOvcTVXdCdx5UvlFhndinTz2r4FfX2A/nwM+\nN099L7B3nvq8c0iSJsO/gJckdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1\nM0wkSd0ME0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1\nM0wkSd0ME0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1K0rTJJc\nmuRrSf53kueT/KMklyXZn+Rg+7mqjU2Su5PMJnk6yTUj+9naxh9MsnWk/r4kB9o2dydJq887hyRp\nMno/mXwJ+EZV/X3gHwDPA9uBR6pqPfBIew5wA7C+PbYB98IwGIA7gfcD1wJ3joTDvcBHR7bb3OoL\nzSFJmoCxwyTJe4APAvcBVNWPq+r7wBZgVxu2C7ipLW8B7q+hR4FLk1wJXA/sr6qjVfUGsB/Y3Na9\nu6oeraoC7j9pX/PNIUmagJ5PJuuAvwD+c5I/TvJ7SS4BpqrqSBvzCjDVllcDL49sf6jVFqsfmqfO\nInNIkiZgRee21wC/UVWPJfkSJ51uqqpKUj0HeDqLzZFkG8NTakxNTTEYDMaaY2ol3HH18bGPsce4\nx3yuHTt27Lw9tkmwH6eyJ3Mt9X70hMkh4FBVPdaef41hmLya5MqqOtJOVb3W1h8GrhrZfk2rHQZm\nTqoPWn3NPONZZI45qmoHsANgenq6ZmZm5ht2Wvc8sJu7Dv
S0anwv3TIzkXlPZzAYMG4/lyL7cSp7\nMtdS78fYp7mq6hXg5SR/r5WuA54D9gAn7sjaCuxuy3uAW9tdXRuBN9upqn3ApiSr2oX3TcC+tu6t\nJBvbXVy3nrSv+eaQJE1A7z+3fwN4IMnFwIvARxgG1ENJbgO+B3yojd0L3AjMAj9sY6mqo0k+Czze\nxn2mqo625Y8BXwFWAg+3B8DnF5hDkjQBXWFSVU8B0/Osum6esQXcvsB+dgI756k/Abx3nvrr880h\nSZoM/wJektTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wkSd0M\nE0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wkSd0M\nE0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVK37jBJclGSP07y39rzdUke\nSzKb5A+SXNzq72rPZ9v6tSP7+HSrv5Dk+pH65labTbJ9pD7vHJKkyTgbn0w+ATw/8vwLwBer6heA\nN4DbWv024I1W/2IbR5INwM3ALwKbgd9tAXUR8GXgBmAD8OE2drE5JEkT0BUmSdYAvwr8Xnse4FeA\nr7Uhu4Cb2vKW9py2/ro2fgvwYFX9qKq+C8wC17bHbFW9WFU/Bh4EtpxmDknSBPR+MvmPwL8F/qY9\nvxz4flUdb88PAavb8mrgZYC2/s02/v/VT9pmofpic0iSJmDFuBsm+SfAa1X1ZJKZs3dIZ0+SbcA2\ngKmpKQaDwVj7mVoJd1x9/PQDz4Fxj/lcO3bs2Hl7bJNgP05lT+Za6v0YO0yADwD/NMmNwE8D7wa+\nBFyaZEX75LAGONzGHwauAg4lWQG8B3h9pH7C6Dbz1V9fZI45qmoHsANgenq6ZmZmxnqh9zywm7sO\n9LRqfC/dMjOReU9nMBgwbj+XIvtxKnsy11Lvx9inuarq01W1pqrWMryA/s2qugX4FvBrbdhWYHdb\n3tOe09Z/s6qq1W9ud3utA9YD3wYeB9a3O7cubnPsadssNIckaQLOxd+ZfAr4ZJJZhtc37mv1+4DL\nW/2TwHaAqnoWeAh4DvgGcHtV/aR96vg4sI/h3WIPtbGLzSFJmoCzcu6mqgbAoC2/yPBOrJPH/DXw\n6wts/zngc/PU9wJ756nPO4ckaTL8C3hJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wk\nSd0ME0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wk\nSd0ME0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wk\nSd3GDpMkVyX5VpLnkjyb5BOtflmS/UkOtp+rWj1J7k4ym+TpJNeM7GtrG38wydaR+vuSHGjb3J0k\ni80hSZqMnk8mx4E7qmoDsBG4PckGYDvwSFWtBx5pzwFuANa3xzbgXhgGA3An8H7gWuDOkXC4F/jo\nyHabW32hOSRJEzB2mFTVkar6Tlv+K+B5YDWwBdjVhu0CbmrLW4D7a+hR4NIkVwLXA/ur6mhVvQHs\nBza3de+uqkerqoD7T9rXfHNIkibgrFwzSbIW+CXgMWCqqo60Va8AU215NfDyyGaHWm2x+qF56iwy\nhyRpAlb07iDJzwJ/CPxmVb3VLmsAUFWVpHrnWMxicyTZxvCUGlNTUwwGg7HmmFoJd1x9fOxj7DHu\nMZ9rx44dO2+PbRLsx6nsyVxLvR9dYZLkpxgGyQNV9fVWfjXJlVV1pJ2qeq3VDwNXjWy+ptUOAzMn\n1Qetvmae8YvNMUdV7QB2AExPT9fMzMx8w07rngd2c9eB7twdy0u3zExk3tMZDAaM28+lyH6cyp7M\ntdT70XM3V4D7gOer6n
dGVu0BTtyRtRXYPVK/td3VtRF4s52q2gdsSrKqXXjfBOxr695KsrHNdetJ\n+5pvDknSBPT8c/sDwD8HDiR5qtX+HfB54KEktwHfAz7U1u0FbgRmgR8CHwGoqqNJPgs83sZ9pqqO\ntuWPAV8BVgIPtweLzCFJmoCxw6Sq/heQBVZfN8/4Am5fYF87gZ3z1J8A3jtP/fX55pAkTYZ/AS9J\n6maYSJK6GSaSpG6GiSSpm2EiSepmmEiSuhkmkqRuhokkqZthIknqZphIkroZJpKkboaJJKmbYSJJ\n6maYSJK6GSaSpG6GiSSpm2EiSepmmEiSuhkmkqRuhokkqZthIknqZphIkroZJpKkboaJJKmbYSJJ\n6maYSJK6GSaSpG6GiSSpm2EiSepmmEiSuhkmkqRuhokkqdsFGyZJNid5Iclsku2TPh5JWs4uyDBJ\nchHwZeAGYAPw4SQbJntUkrR8XZBhAlwLzFbVi1X1Y+BBYMuEj0mSlq0LNUxWAy+PPD/UapKkCVgx\n6QM4l5JsA7a1p8eSvDDmrq4A/vLsHNWZyRcmMevbMrGenKfsx6nsyVwXaj/+ztsZdKGGyWHgqpHn\na1ptjqraAezonSzJE1U13bufpcSezGU/TmVP5lrq/bhQT3M9DqxPsi7JxcDNwJ4JH5MkLVsX5CeT\nqjqe5OPAPuAiYGdVPTvhw5KkZeuCDBOAqtoL7H2Hpus+VbYE2ZO57Mep7MlcS7ofqapJH4Mk6QJ3\noV4zkSSdRwyT01iuX9uS5KUkB5I8leSJVrssyf4kB9vPVa2eJHe3Hj2d5JrJHv3ZkWRnkteSPDNS\nO+MeJNnaxh9MsnUSr+VsWKAfv53kcHufPJXkxpF1n279eCHJ9SP1JfE7leSqJN9K8lySZ5N8otWX\n53ukqnws8GB4cf/PgJ8HLgb+BNgw6eN6h177S8AVJ9X+PbC9LW8HvtCWbwQeBgJsBB6b9PGfpR58\nELgGeGbcHgCXAS+2n6va8qpJv7az2I/fBv7NPGM3tN+XdwHr2u/RRUvpdwq4ErimLf8c8KftdS/L\n94ifTBbn17bMtQXY1ZZ3ATeN1O+voUeBS5NcOYkDPJuq6o+AoyeVz7QH1wP7q+poVb0B7Ac2n/uj\nP/sW6MdCtgAPVtWPquq7wCzD36cl8ztVVUeq6jtt+a+A5xl+E8eyfI8YJotbzl/bUsD/SPJk+yYB\ngKmqOtKWXwGm2vJy6tOZ9mA59Obj7bTNzhOndFhm/UiyFvgl4DGW6XvEMNFCfrmqrmH4zcy3J/ng\n6Moafj5f1rcC2gMA7gX+LvAPgSPAXZM9nHdekp8F/hD4zap6a3TdcnqPGCaLe1tf27IUVdXh9vM1\n4L8yPD3x6onTV+3na234curTmfZgSfemql6tqp9U1d8A/4nh+wSWST+S/BTDIHmgqr7eysvyPWKY\nLG5Zfm1LkkuS/NyJZWAT8AzD137iTpOtwO62vAe4td2tshF4c+Rj/lJzpj3YB2xKsqqdAtrUakvC\nSdfG/hnD9wkM+3FzknclWQesB77NEvqdShLgPuD5qvqdkVXL8z0y6TsAzvcHwzsw/pThHSi/Nenj\neYde888zvMvmT4BnT7xu4HLgEeAg8D+By1o9DP9nZX8GHACmJ/0azlIfvsrw1M3/YXge+7ZxegD8\nS4YXoGeBj0z6dZ3lfvx+e71PM/yP5ZUj43+r9eMF4IaR+pL4nQJ+meEprKeBp9rjxuX6HvEv4CVJ\n3TzNJUnqZphIkroZJpKkboaJJKmbYSJJ6maYSJK6GSaSpG6GiSSp2/8Fqu+LlacobiMAAAAASUVO\nRK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "count 1.010773e+06\n", + "mean 1.042197e+01\n", + "std 1.050239e+01\n", + "min 0.000000e+00\n", + "25% 5.000000e+00\n", + "50% 9.000000e+00\n", + "75% 1.400000e+01\n", + "max 2.222000e+03\n", + "dtype: float64" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 39 + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "LBgg-Fe6JcF0", + "colab_type": "code", + "colab": {} + }, + "source": [ + "sequence_length=500\n", + "features=np.zeros((len(encoded_reviews), sequence_length), dtype=int)\n", + "for i, review in enumerate(encoded_reviews):\n", + " review_len=len(review)\n", + " if (review_len<=sequence_length):\n", + " zeros=list(np.zeros(sequence_length-review_len))\n", + " new=zeros+review\n", + " else:\n", + " new=review[:sequence_length]\n", + "features[i,:]=np.array(new)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "q-njY7CiKk4T", + "colab_type": "code", + "colab": {} + }, + "source": [ + "labels=train_data['label']" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "KSkcQMknKF-k", + "colab_type": "code", + "outputId": "f00d2da4-ec29-4820-a807-6b70e703248b", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + } + }, + "source": [ + "train_x=features[:int(0.8*len(features))]\n", + "train_y=labels[:int(0.8*len(features))]\n", + "valid_x=features[int(0.8*len(features)):]\n", + "valid_y=labels[int(0.8*len(features)):]\n", + "train_y=train_y.to_numpy()\n", + "valid_y=valid_y.to_numpy()\n", + "print(type(train_x), len(valid_y))\n", + "print(len(train_data))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + " 202155\n", + "1010773\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zTy3ma7HNomd", + "colab_type": "code", + 
"colab": {} + }, + "source": [ + "traindata = TensorDataset(torch.from_numpy(train_x), torch.from_numpy(train_y))\n", + "validdata = TensorDataset(torch.from_numpy(valid_x), torch.from_numpy(valid_y))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "FFDDBYvtXvzX", + "colab_type": "code", + "colab": {} + }, + "source": [ + "batch_size=50\n", + "train_loader=DataLoader(traindata, batch_size=batch_size, shuffle=True)\n", + "valid_loader=DataLoader(validdata, batch_size=batch_size, shuffle=True)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "KnhnFbL0XwFr", + "colab_type": "code", + "outputId": "2d496352-0ae6-4978-d1da-073ecda2a5f2", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 272 + } + }, + "source": [ + "dataiter = iter(train_loader)\n", + "sample_x, sample_y = dataiter.next()\n", + "print('Sample input size: ', sample_x.size()) # batch_size, seq_length\n", + "print('Sample input: \\n', sample_x)\n", + "print()\n", + "print('Sample label size: ', sample_y.size()) # batch_size\n", + "print('Sample label: \\n', sample_y)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Sample input size: torch.Size([50, 500])\n", + "Sample input: \n", + " tensor([[0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " ...,\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0],\n", + " [0, 0, 0, ..., 0, 0, 0]])\n", + "\n", + "Sample label size: torch.Size([50])\n", + "Sample label: \n", + " tensor([1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0,\n", + " 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0,\n", + " 0, 1])\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lov_X4LhZZSP", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import torch.nn as nn\n", + "\n", + 
"class SentimentLSTM(nn.Module):\n", + " \n", + "\n", + " def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5):\n", + " \"\"\"\n", + " Initialize the model by setting up the layers.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.output_size = output_size\n", + " self.n_layers = n_layers\n", + " self.hidden_dim = hidden_dim\n", + " \n", + " \n", + " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n", + " self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, \n", + " dropout=drop_prob, batch_first=True)\n", + " \n", + " \n", + " self.dropout = nn.Dropout(0.3)\n", + " \n", + " \n", + " self.fc = nn.Linear(hidden_dim, output_size)\n", + " self.sig = nn.Sigmoid()\n", + " \n", + "\n", + " def forward(self, x, hidden):\n", + " \n", + " batch_size = x.size(0)\n", + "\n", + " \n", + " embeds = self.embedding(x)\n", + " lstm_out, hidden = self.lstm(embeds, hidden)\n", + " \n", + " \n", + " lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)\n", + " out = self.dropout(lstm_out)\n", + " out = self.fc(out)\n", + " \n", + " sig_out = self.sig(out)\n", + " \n", + " sig_out = sig_out.view(batch_size, -1)\n", + " sig_out = sig_out[:, -1] \n", + " return sig_out, hidden\n", + " \n", + " \n", + " def init_hidden(self, batch_size):\n", + " \n", + " weight = next(self.parameters()).data\n", + " \n", + " if (train_on_gpu):\n", + " hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda(),\n", + " weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda())\n", + " else:\n", + " hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_(),\n", + " weight.new(self.n_layers, batch_size, self.hidden_dim).zero_())\n", + " \n", + " return hidden\n", + " " + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "3IoUHKG6bkGb", + "colab_type": "code", + "outputId": "207f1a1b-3769-4bac-8a5f-2a3418ad4048", + "colab": { + "base_uri": 
"https://localhost:8080/", + "height": 136 + } + }, + "source": [ + "vocab_size = len(vocab_to_int)+1 # +1 for the 0 padding\n", + "output_size = 1\n", + "embedding_dim = 400\n", + "hidden_dim = 256\n", + "n_layers = 2\n", + "net = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)\n", + "print(net)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "SentimentLSTM(\n", + " (embedding): Embedding(206968, 400)\n", + " (lstm): LSTM(400, 256, num_layers=2, batch_first=True, dropout=0.5)\n", + " (dropout): Dropout(p=0.3)\n", + " (fc): Linear(in_features=256, out_features=1, bias=True)\n", + " (sig): Sigmoid()\n", + ")\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "f26FXAbUeS72", + "colab_type": "code", + "outputId": "6191d34f-f788-415c-9e8d-6e8ed84dffbb", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 394 + } + }, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "train_on_gpu = torch.cuda.is_available()\n", + "lr=0.001\n", + "\n", + "criterion = nn.BCELoss()\n", + "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n", + "\n", + "epochs = 4 \n", + "\n", + "counter = 0\n", + "print_every = 100\n", + "clip=5\n", + "if(train_on_gpu):\n", + " net.cuda()\n", + "\n", + "net.train()\n", + "for e in range(epochs):\n", + " \n", + " h = net.init_hidden(batch_size)\n", + " for inputs, labels in train_loader:\n", + " counter += 1\n", + "\n", + " if(train_on_gpu):\n", + " inputs, labels = inputs.cuda(), labels.cuda()\n", + "\n", + " h = tuple([each.data for each in h])\n", + "\n", + " # zero accumulated gradients\n", + " net.zero_grad()\n", + "\n", + " # get the output from the model\n", + " inputs = inputs.type(torch.cuda.LongTensor)\n", + " output,h = net(inputs, h)\n", + " print(output.shape)\n", + " # calculate the loss and perform backprop\n", + " loss = criterion(output.squeeze(), labels.float())\n", + " 
loss.backward(retain_graph=True)\n", + " \n", + " nn.utils.clip_grad_norm_(net.parameters(), clip)\n", + " optimizer.step()\n", + "\n", + " \n", + " if counter % print_every == 0:\n", + " \n", + " val_h = net.init_hidden(batch_size)\n", + " val_losses = []\n", + " net.eval()\n", + " for inputs, labels in valid_loader:\n", + "\n", + " \n", + " val_h = tuple([each.data for each in val_h])\n", + "\n", + " if(train_on_gpu):\n", + " inputs, labels = inputs.cuda(), labels.cuda()\n", + "\n", + " inputs = inputs.type(torch.cuda.LongTensor)\n", + " output, val_h = net(inputs, val_h)\n", + " val_loss = criterion(output.squeeze(), labels.float())\n", + "\n", + " val_losses.append(val_loss.item())\n", + "\n", + " net.train()\n", + " print(\"Epoch: {}/{}...\".format(e+1, epochs),\n", + " \"Step: {}...\".format(counter),\n", + " \"Loss: {:.6f}...\".format(loss.item()),\n", + " \"Val Loss: {:.6f}\".format(np.mean(val_losses)))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "error", + "ename": "RuntimeError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_h\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mval_h\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0mval_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 491\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 492\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 494\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 495\u001b[0m 
\u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, hidden)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0membeds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0mlstm_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlstm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 491\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 492\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 494\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 557\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_packed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 559\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 560\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 561\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + 
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward_tensor\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 537\u001b[0m \u001b[0munsorted_indices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 539\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_batch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msorted_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 540\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 541\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermute_hidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munsorted_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, input, hx, batch_sizes, max_batch_size, sorted_indices)\u001b[0m\n\u001b[1;32m 517\u001b[0m \u001b[0mhx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermute_hidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msorted_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 518\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 519\u001b[0;31m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_forward_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 520\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_sizes\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 521\u001b[0m result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mcheck_forward_args\u001b[0;34m(self, input, hidden, batch_sizes)\u001b[0m\n\u001b[1;32m 492\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 493\u001b[0m self.check_hidden_size(hidden[0], expected_hidden_size,\n\u001b[0;32m--> 494\u001b[0;31m 'Expected hidden[0] size {}, got {}')\n\u001b[0m\u001b[1;32m 495\u001b[0m self.check_hidden_size(hidden[1], expected_hidden_size,\n\u001b[1;32m 496\u001b[0m 'Expected hidden[1] size {}, got {}')\n", + "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mcheck_hidden_size\u001b[0;34m(self, hx, expected_hidden_size, msg)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;31m# type: (Tensor, Tuple[int, int, int], str) -> None\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mexpected_hidden_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 172\u001b[0;31m \u001b[0;32mraise\u001b[0m 
\u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpected_hidden_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 173\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcheck_forward_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Expected hidden[0] size (2, 5, 256), got (2, 50, 256)" + ] + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "eB_CGTGET05V", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "BSheCkROT095", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "M3CXZzIET1Iz", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "0vfJvTiaT1O2", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "UeGe55t6T1Sk", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "tWbk1HpzT1W0", + 
"colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "5cyxIOvlT1bi", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "AhywpMP7T1gU", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "_MXann70T1mB", + "colab_type": "code", + "colab": {} + }, + "source": [ + "" + ], + "execution_count": 0, + "outputs": [] + } + ] +} \ No newline at end of file