diff --git a/SarcasmDetector/README.md b/SarcasmDetector/README.md
new file mode 100644
index 00000000..54a6889c
--- /dev/null
+++ b/SarcasmDetector/README.md
@@ -0,0 +1,14 @@
+# Sarcasm Detector
+
+tldr? A text classifier that detects sarcasm.
+
+As is well known to fams and fans of The Big Bang Theory, Sheldon Cooper, one of the most brilliant minds of 21st century American Entertainment, sadly struggles to detect sarcasm in casual conversation. So a bunch of Cool Women in Tech decided to partner up and solve this pressing problem in modern pop culture.
+
+In this repo we document this journey for the reference of techies undertaking a similar quest in the near or far future! We hope you enjoy the process as much as we did, and hope to make it easier for you to follow along! ^_^
+
+
+>contributors
+
+>technical details
+
+>
diff --git a/SaveSheldon/sarcasm.ipynb b/SaveSheldon/sarcasm.ipynb
new file mode 100644
index 00000000..86c5bfac
--- /dev/null
+++ b/SaveSheldon/sarcasm.ipynb
@@ -0,0 +1,1772 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "sarcasm.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "collapsed_sections": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "BHq8guDSTy8q",
+ "colab_type": "code",
+ "outputId": "66d42816-4aba-487a-933a-2d8207b886ee",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 122
+ }
+ },
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code\n",
+ "\n",
+ "Enter your authorization code:\n",
+ "··········\n",
+ "Mounted at /content/drive\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "SimktZGaTz5H",
+ "colab_type": "code",
+ "outputId": "889df6a8-5791-4a1b-c15f-6a1228997239",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ }
+ },
+ "source": [
+ "!unzip '/content/drive/My Drive/train-balanced-sarcasm.csv.zip'"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Archive: /content/drive/My Drive/train-balanced-sarcasm.csv.zip\n",
+ " inflating: train-balanced-sarcasm.csv \n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1JuN0ZYyW1SL",
+ "colab_type": "code",
+ "outputId": "c00ea784-5380-484c-e730-24e78d3498e2",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ }
+ },
+ "source": [
+ "import numpy as np # linear algebra\n",
+ "import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n",
+ "from plotly import tools\n",
+ "import plotly.offline as py\n",
+ "py.init_notebook_mode(connected=True)\n",
+ "import plotly.graph_objs as go\n",
+ "from collections import defaultdict\n",
+ "from matplotlib import pyplot as plt\n",
+ "%matplotlib inline\n",
+ "import re\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader, TensorDataset"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/vnd.plotly.v1+html": "",
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "2lQDgLiVTz60",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "train_data=pd.read_csv('/content/train-balanced-sarcasm.csv')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Y1WgCA09Tz93",
+ "colab_type": "code",
+ "outputId": "49d327d5-e483-44c3-9ae6-c829ef389bf1",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 374
+ }
+ },
+ "source": [
+ "train_data.head()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " label | \n",
+ " comment | \n",
+ " author | \n",
+ " subreddit | \n",
+ " score | \n",
+ " ups | \n",
+ " downs | \n",
+ " date | \n",
+ " created_utc | \n",
+ " parent_comment | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0 | \n",
+ " NC and NH. | \n",
+ " Trumpbart | \n",
+ " politics | \n",
+ " 2 | \n",
+ " -1 | \n",
+ " -1 | \n",
+ " 2016-10 | \n",
+ " 2016-10-16 23:55:23 | \n",
+ " Yeah, I get that argument. At this point, I'd ... | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 0 | \n",
+ " You do know west teams play against west teams... | \n",
+ " Shbshb906 | \n",
+ " nba | \n",
+ " -4 | \n",
+ " -1 | \n",
+ " -1 | \n",
+ " 2016-11 | \n",
+ " 2016-11-01 00:24:10 | \n",
+ " The blazers and Mavericks (The wests 5 and 6 s... | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 0 | \n",
+ " They were underdogs earlier today, but since G... | \n",
+ " Creepeth | \n",
+ " nfl | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 0 | \n",
+ " 2016-09 | \n",
+ " 2016-09-22 21:45:37 | \n",
+ " They're favored to win. | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 0 | \n",
+ " This meme isn't funny none of the \"new york ni... | \n",
+ " icebrotha | \n",
+ " BlackPeopleTwitter | \n",
+ " -8 | \n",
+ " -1 | \n",
+ " -1 | \n",
+ " 2016-10 | \n",
+ " 2016-10-18 21:03:47 | \n",
+ " deadass don't kill my buzz | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 0 | \n",
+ " I could use one of those tools. | \n",
+ " cush2push | \n",
+ " MaddenUltimateTeam | \n",
+ " 6 | \n",
+ " -1 | \n",
+ " -1 | \n",
+ " 2016-12 | \n",
+ " 2016-12-30 17:00:13 | \n",
+ " Yep can confirm I saw the tool they use for th... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " label ... parent_comment\n",
+ "0 0 ... Yeah, I get that argument. At this point, I'd ...\n",
+ "1 0 ... The blazers and Mavericks (The wests 5 and 6 s...\n",
+ "2 0 ... They're favored to win.\n",
+ "3 0 ... deadass don't kill my buzz\n",
+ "4 0 ... Yep can confirm I saw the tool they use for th...\n",
+ "\n",
+ "[5 rows x 10 columns]"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 18
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ogouztYRUChv",
+ "colab_type": "code",
+ "outputId": "fc2e8096-111d-41b9-8556-8b7ba587d6c1",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 272
+ }
+ },
+ "source": [
+ "train_data.info()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "RangeIndex: 1010826 entries, 0 to 1010825\n",
+ "Data columns (total 10 columns):\n",
+ "label 1010826 non-null int64\n",
+ "comment 1010773 non-null object\n",
+ "author 1010826 non-null object\n",
+ "subreddit 1010826 non-null object\n",
+ "score 1010826 non-null int64\n",
+ "ups 1010826 non-null int64\n",
+ "downs 1010826 non-null int64\n",
+ "date 1010826 non-null object\n",
+ "created_utc 1010826 non-null object\n",
+ "parent_comment 1010826 non-null object\n",
+ "dtypes: int64(4), object(6)\n",
+ "memory usage: 77.1+ MB\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Yil_0nYRTz_v",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "train_data.dropna(subset=['comment'], inplace=True)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "7lbG1vVsT0C8",
+ "colab_type": "code",
+ "outputId": "1ea56472-0789-4d02-f1e3-beba074c9d63",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 272
+ }
+ },
+ "source": [
+ "train_data.info()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Int64Index: 1010773 entries, 0 to 1010825\n",
+ "Data columns (total 10 columns):\n",
+ "label 1010773 non-null int64\n",
+ "comment 1010773 non-null object\n",
+ "author 1010773 non-null object\n",
+ "subreddit 1010773 non-null object\n",
+ "score 1010773 non-null int64\n",
+ "ups 1010773 non-null int64\n",
+ "downs 1010773 non-null int64\n",
+ "date 1010773 non-null object\n",
+ "created_utc 1010773 non-null object\n",
+ "parent_comment 1010773 non-null object\n",
+ "dtypes: int64(4), object(6)\n",
+ "memory usage: 84.8+ MB\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "a2l6PiF4T0He",
+ "colab_type": "code",
+ "outputId": "f4141444-733f-4964-87ac-a4299e6f440c",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "train_data['label'].unique()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "array([0, 1])"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 21
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1RmuDIUxT0MM",
+ "colab_type": "code",
+ "outputId": "237c9949-4516-4305-f291-735e0ca1de32",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ }
+ },
+ "source": [
+ "train_data.loc[train_data['label'] == 1, 'comment'].str.len()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "33 35\n",
+ "44 145\n",
+ "45 107\n",
+ "66 67\n",
+ "69 65\n",
+ "74 40\n",
+ "76 116\n",
+ "160 30\n",
+ "169 56\n",
+ "193 182\n",
+ "194 58\n",
+ "216 188\n",
+ "230 56\n",
+ "255 100\n",
+ "274 22\n",
+ "278 60\n",
+ "305 33\n",
+ "336 42\n",
+ "346 34\n",
+ "359 28\n",
+ "370 39\n",
+ "443 96\n",
+ "457 28\n",
+ "479 139\n",
+ "487 43\n",
+ "497 35\n",
+ "508 34\n",
+ "526 26\n",
+ "556 48\n",
+ "575 5\n",
+ " ... \n",
+ "1010794 69\n",
+ "1010795 49\n",
+ "1010796 5\n",
+ "1010797 65\n",
+ "1010798 20\n",
+ "1010799 9\n",
+ "1010800 64\n",
+ "1010801 143\n",
+ "1010802 184\n",
+ "1010803 21\n",
+ "1010805 38\n",
+ "1010806 20\n",
+ "1010807 219\n",
+ "1010808 53\n",
+ "1010809 338\n",
+ "1010810 23\n",
+ "1010811 17\n",
+ "1010812 65\n",
+ "1010813 25\n",
+ "1010814 131\n",
+ "1010816 13\n",
+ "1010817 77\n",
+ "1010818 34\n",
+ "1010819 38\n",
+ "1010820 2\n",
+ "1010821 92\n",
+ "1010822 34\n",
+ "1010823 66\n",
+ "1010824 53\n",
+ "1010825 72\n",
+ "Name: comment, Length: 505368, dtype: int64"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 22
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "4h-5SFSaT0OM",
+ "colab_type": "code",
+ "outputId": "0b8c922d-8b27-4bd6-aea5-6302a60b4fa1",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 785
+ }
+ },
+ "source": [
+ "cnt_label = train_data['label'].value_counts()\n",
+ "print(cnt_label)\n",
+ "labels = (np.array(cnt_label.index))\n",
+ "sizes=(np.array((cnt_label/cnt_label.sum())*100))\n",
+ "print(sizes)\n",
+ "trace = go.Pie(labels=labels, values=sizes)\n",
+ "layout = go.Layout(title='Label Distribution', font=dict(size=15),\n",
+ " width=700, height=700)\n",
+ "\n",
+ "data = [trace]\n",
+ "fig = go.Figure(data=data, layout=layout)\n",
+ "py.iplot(fig, filename=\"LabelDistribution\")"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "0 505405\n",
+ "1 505368\n",
+ "Name: label, dtype: int64\n",
+ "[50.00183028 49.99816972]\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/vnd.plotly.v1+html": "",
+ "text/html": [
+ ""
+ ],
+ "application/vnd.plotly.v1+json": {
+ "layout": {
+ "title": {
+ "text": "Label Distribution"
+ },
+ "font": {
+ "size": 15
+ },
+ "width": 700,
+ "height": 700
+ },
+ "config": {
+ "plotlyServerURL": "https://plot.ly",
+ "linkText": "Export to plot.ly",
+ "showLink": false
+ },
+ "data": [
+ {
+ "type": "pie",
+ "labels": [
+ 0,
+ 1
+ ],
+ "values": [
+ 50.001830282368054,
+ 49.99816971763195
+ ],
+ "uid": "7e84568b-b29c-4213-a806-387dca1a08ed"
+ }
+ ]
+ }
+ },
+ "metadata": {
+ "tags": []
+ }
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "mbtWNnH3bype",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "def horizontal_bar_chart(df, color):\n",
+ " trace = go.Bar(y = df[\"word\"].values[::-1],\n",
+ " x = df[\"wordcount\"].values[::-1],\n",
+ " showlegend=False,\n",
+ " orientation='h',\n",
+ " marker=dict(color=color))\n",
+ " return trace"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_1TAGy0wT0Ra",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "train1_df = train_data[train_data['label']==1]\n",
+ "train0_df = train_data[train_data['label']==0]\n",
+ "from wordcloud import WordCloud, STOPWORDS\n",
+ "def generate_ngrams(text, n_gram=1):\n",
+ " token = [token for token in text.lower().split(\" \") if token!= \"\" if token not\n",
+ " in STOPWORDS]\n",
+ " \n",
+ " ngrams = zip(*[token[i:] for i in range(n_gram)])\n",
+ " return [\" \".join(ngram) for ngram in ngrams]\n",
+ "\n",
+ "freq_dict = defaultdict(int)\n",
+ "for sent in train0_df[\"comment\"]:\n",
+ " for word in generate_ngrams(sent):\n",
+ " freq_dict[word] += 1\n",
+ "fd_sorted = pd.DataFrame(sorted(freq_dict.items(), key=lambda x: x[1])[::-1])\n",
+ "fd_sorted.columns = [\"word\", \"wordcount\"]\n",
+ "trace1 = horizontal_bar_chart(fd_sorted.head(50), 'blue')\n",
+ " "
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_agqrum1T0WM",
+ "colab_type": "code",
+ "outputId": "2b002a72-3ff2-42f4-adbf-120abbc8967a",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 390
+ }
+ },
+ "source": [
+ "sub_df = train_data.groupby('subreddit')['label'].agg([np.size, np.mean, np.sum])\n",
+ "sub_df.sort_values(by='sum', ascending=False).head(10)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " size | \n",
+ " mean | \n",
+ " sum | \n",
+ "
\n",
+ " \n",
+ " | subreddit | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | AskReddit | \n",
+ " 65674 | \n",
+ " 0.401453 | \n",
+ " 26365 | \n",
+ "
\n",
+ " \n",
+ " | politics | \n",
+ " 39493 | \n",
+ " 0.605348 | \n",
+ " 23907 | \n",
+ "
\n",
+ " \n",
+ " | worldnews | \n",
+ " 26376 | \n",
+ " 0.642516 | \n",
+ " 16947 | \n",
+ "
\n",
+ " \n",
+ " | leagueoflegends | \n",
+ " 21034 | \n",
+ " 0.542312 | \n",
+ " 11407 | \n",
+ "
\n",
+ " \n",
+ " | pcmasterrace | \n",
+ " 18987 | \n",
+ " 0.566651 | \n",
+ " 10759 | \n",
+ "
\n",
+ " \n",
+ " | news | \n",
+ " 16891 | \n",
+ " 0.603457 | \n",
+ " 10193 | \n",
+ "
\n",
+ " \n",
+ " | funny | \n",
+ " 17939 | \n",
+ " 0.451474 | \n",
+ " 8099 | \n",
+ "
\n",
+ " \n",
+ " | pics | \n",
+ " 16152 | \n",
+ " 0.484336 | \n",
+ " 7823 | \n",
+ "
\n",
+ " \n",
+ " | todayilearned | \n",
+ " 14159 | \n",
+ " 0.547567 | \n",
+ " 7753 | \n",
+ "
\n",
+ " \n",
+ " | GlobalOffensive | \n",
+ " 13738 | \n",
+ " 0.552045 | \n",
+ " 7584 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " size mean sum\n",
+ "subreddit \n",
+ "AskReddit 65674 0.401453 26365\n",
+ "politics 39493 0.605348 23907\n",
+ "worldnews 26376 0.642516 16947\n",
+ "leagueoflegends 21034 0.542312 11407\n",
+ "pcmasterrace 18987 0.566651 10759\n",
+ "news 16891 0.603457 10193\n",
+ "funny 17939 0.451474 8099\n",
+ "pics 16152 0.484336 7823\n",
+ "todayilearned 14159 0.547567 7753\n",
+ "GlobalOffensive 13738 0.552045 7584"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 26
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "e4DWndGiT0Yv",
+ "colab_type": "code",
+ "outputId": "00564040-bd8b-4a07-e125-0c21b10c5536",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 390
+ }
+ },
+ "source": [
+ "sub_df[sub_df['size'] > 1000].sort_values(by='mean', ascending=False).head(10)\n"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " size | \n",
+ " mean | \n",
+ " sum | \n",
+ "
\n",
+ " \n",
+ " | subreddit | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | creepyPMs | \n",
+ " 5466 | \n",
+ " 0.784303 | \n",
+ " 4287 | \n",
+ "
\n",
+ " \n",
+ " | MensRights | \n",
+ " 3355 | \n",
+ " 0.680775 | \n",
+ " 2284 | \n",
+ "
\n",
+ " \n",
+ " | ShitRedditSays | \n",
+ " 1284 | \n",
+ " 0.661994 | \n",
+ " 850 | \n",
+ "
\n",
+ " \n",
+ " | worldnews | \n",
+ " 26376 | \n",
+ " 0.642516 | \n",
+ " 16947 | \n",
+ "
\n",
+ " \n",
+ " | Libertarian | \n",
+ " 2562 | \n",
+ " 0.640125 | \n",
+ " 1640 | \n",
+ "
\n",
+ " \n",
+ " | atheism | \n",
+ " 7377 | \n",
+ " 0.639555 | \n",
+ " 4718 | \n",
+ "
\n",
+ " \n",
+ " | Conservative | \n",
+ " 1881 | \n",
+ " 0.639553 | \n",
+ " 1203 | \n",
+ "
\n",
+ " \n",
+ " | TwoXChromosomes | \n",
+ " 1560 | \n",
+ " 0.632692 | \n",
+ " 987 | \n",
+ "
\n",
+ " \n",
+ " | fatlogic | \n",
+ " 2356 | \n",
+ " 0.623090 | \n",
+ " 1468 | \n",
+ "
\n",
+ " \n",
+ " | facepalm | \n",
+ " 1268 | \n",
+ " 0.617508 | \n",
+ " 783 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " size mean sum\n",
+ "subreddit \n",
+ "creepyPMs 5466 0.784303 4287\n",
+ "MensRights 3355 0.680775 2284\n",
+ "ShitRedditSays 1284 0.661994 850\n",
+ "worldnews 26376 0.642516 16947\n",
+ "Libertarian 2562 0.640125 1640\n",
+ "atheism 7377 0.639555 4718\n",
+ "Conservative 1881 0.639553 1203\n",
+ "TwoXChromosomes 1560 0.632692 987\n",
+ "fatlogic 2356 0.623090 1468\n",
+ "facepalm 1268 0.617508 783"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 27
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "np0jP520T0bL",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "train_data['comment']=train_data['comment'].str.lower()\n",
+ "train_data['parent_comment']=(train_data['parent_comment']).str.lower()\n",
+ "train_data['author']=(train_data['comment']).str.lower()\n",
+ "train_data['subreddit']=(train_data['subreddit']).str.lower()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "VFPRDEpjT0gQ",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "train_data['comment']=train_data['comment'].replace('[^\\w\\s]','', regex=True)\n",
+ "train_data['parent_comment']=train_data['parent_comment'].replace('[^\\w\\s]','', regex=True)\n",
+ "train_data['author']=train_data['author'].replace('[^\\w\\s]','', regex=True)\n",
+ "train_data['subreddit']=train_data['subreddit'].replace('[^\\w\\s]','', regex=True)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "H-SaNb4U_jrg",
+ "colab_type": "code",
+ "outputId": "51b60e43-6db3-459b-91f4-af992a1e1060",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "train_data['parent_comment'][4]"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "'yep can confirm i saw the tool they use for that it was made by our boy easports_mut'"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 18
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "83lTIy9sSP4f",
+ "colab_type": "code",
+ "outputId": "5fbff3cd-e2c0-493d-b22a-c68bad7a1dfe",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 261
+ }
+ },
+ "source": [
+ "train_data.head(3)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " label | \n",
+ " comment | \n",
+ " author | \n",
+ " subreddit | \n",
+ " score | \n",
+ " ups | \n",
+ " downs | \n",
+ " date | \n",
+ " created_utc | \n",
+ " parent_comment | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 0 | \n",
+ " nc and nh | \n",
+ " nc and nh | \n",
+ " politics | \n",
+ " 2 | \n",
+ " -1 | \n",
+ " -1 | \n",
+ " 2016-10 | \n",
+ " 2016-10-16 23:55:23 | \n",
+ " yeah i get that argument at this point id pref... | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 0 | \n",
+ " you do know west teams play against west teams... | \n",
+ " you do know west teams play against west teams... | \n",
+ " nba | \n",
+ " -4 | \n",
+ " -1 | \n",
+ " -1 | \n",
+ " 2016-11 | \n",
+ " 2016-11-01 00:24:10 | \n",
+ " the blazers and mavericks the wests 5 and 6 se... | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 0 | \n",
+ " they were underdogs earlier today but since gr... | \n",
+ " they were underdogs earlier today but since gr... | \n",
+ " nfl | \n",
+ " 3 | \n",
+ " 3 | \n",
+ " 0 | \n",
+ " 2016-09 | \n",
+ " 2016-09-22 21:45:37 | \n",
+ " theyre favored to win | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " label ... parent_comment\n",
+ "0 0 ... yeah i get that argument at this point id pref...\n",
+ "1 0 ... the blazers and mavericks the wests 5 and 6 se...\n",
+ "2 0 ... theyre favored to win\n",
+ "\n",
+ "[3 rows x 10 columns]"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 19
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "tKQREGVsT0k7",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "train_data.drop(['author','subreddit','score','ups','downs','date','created_utc'], axis=1, inplace=True)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "PPwJAPMhbMOU",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "train_data.drop(['parent_comment'], axis=1, inplace=True)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "R3jSvxvG-5f5",
+ "colab_type": "code",
+ "outputId": "9c99724b-b742-4bbb-f706-349e9dc1e872",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "from string import punctuation\n",
+ "print(punctuation)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "IadaqnSA-_sI",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "all_reviews=list()\n",
+ "for text in train_data['comment']:\n",
+ " text = text.lower()\n",
+ " text = \"\".join([ch for ch in text if ch not in punctuation])\n",
+ " all_reviews.append(text)\n",
+ "all_text = \" \".join(all_reviews)\n",
+ "all_words = all_text.split()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "sqMdU79Z_N_R",
+ "colab_type": "code",
+ "outputId": "6937b862-1f5c-4284-a52e-e3470f77a251",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "print(all_words[:10])"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "['nc', 'and', 'nh', 'you', 'do', 'know', 'west', 'teams', 'play', 'against']\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "lGKHqJNK_6-g",
+ "colab_type": "code",
+ "outputId": "87ce60ac-e1b6-414e-b5a6-368dd7a399b3",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ }
+ },
+ "source": [
+ "print(type(all_words))"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "7ic1z5Nh_OMi",
+ "colab_type": "code",
+ "outputId": "d43e82d8-8b15-4b07-8eaf-6e0549ed4f07",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 54
+ }
+ },
+ "source": [
+ "from collections import Counter \n",
+ "# Count all the words using Counter Method\n",
+ "count_words = Counter(all_words)\n",
+ "total_words=len(all_words)\n",
+ "(sorted_words)=count_words.most_common(total_words)\n",
+ "print(sorted_words[:10])\n",
+ "#print(\"Top ten occuring words : \"+sorted_words[:10])"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "[('the', 396228), ('a', 244958), ('to', 239184), ('i', 184355), ('and', 173669), ('you', 171489), ('is', 154657), ('of', 149812), ('that', 140694), ('it', 128142)]\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "0q4X41SsGzUn",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "vocab_to_int={w:i+1 for i,(w,c) in enumerate(sorted_words)}\n",
+ "#print(vocab_to_int)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "in-yWAc7GzSH",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "encoded_reviews=list()\n",
+ "for review in all_reviews:\n",
+ " encoded_review=list()\n",
+ " for word in review.split():\n",
+ " if word not in vocab_to_int.keys():\n",
+ " #if word is not available in vocab_to_int put 0 in that place\n",
+ " encoded_review.append(0)\n",
+ " else:\n",
+ " encoded_review.append(vocab_to_int[word])\n",
+ " encoded_reviews.append(encoded_review)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "qaK3zcKOIKJ-",
+ "colab_type": "code",
+ "outputId": "1d0be026-1b2b-431c-953d-59b192e8a536",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 422
+ }
+ },
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "%matplotlib inline\n",
+ "reviews_len = [len(x) for x in encoded_reviews]\n",
+ "pd.Series(reviews_len).hist()\n",
+ "plt.show()\n",
+ "pd.Series(reviews_len).describe()"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZMAAAD8CAYAAACyyUlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAEzNJREFUeJzt3X+s3fV93/Hna7ikLm2CgekK2Wh2\nV2uTG7SVXBFPqaKrMhlDp5lJbUSEhpeh+I+QLl2YFmf9gypRpGQazQJKkbzixVQolKWZbG0mnkdy\nVO0PCJBSzI9R3xJSbBloMYE6UZM5fe+P8/F2rn3vNT4fm2Pf+3xIR/d73t/P9/v5nrfO5eXz/X7v\nIVWFJEk9/takD0CSdOEzTCRJ3QwTSVI3w0SS1M0wkSR1M0wkSd0ME0lSN8NEktTNMJEkdVsx6QN4\np1xxxRW1du3asbb9wQ9+wCWXXHJ2D+gCZ0/msh+nsidzXaj9ePLJJ/+yqv726cYtmzBZu3YtTzzx\nxFjbDgYDZmZmzu4BXeDsyVz241T2ZK4LtR9Jvvd2xnmaS5LUzTCRJHUzTCRJ3QwTSVI3w0SS1M0w\nkSR1M0wkSd1OGyZJdiZ5LckzI7XLkuxPcrD9XNXqSXJ3ktkkTye5ZmSbrW38wSRbR+rvS3KgbXN3\nkow7hyRpMt7OJ5OvAJtPqm0HHqmq9cAj7TnADcD69tgG3AvDYADuBN4PXAvceSIc2piPjmy3eZw5\nJEmTc9q/gK+qP0qy9qTyFmCmLe8CBsCnWv3+qirg0SSXJrmyjd1fVUcBkuwHNicZAO+uqkdb/X7g\nJuDhM52jqo6c2Ut/+w4cfpN/sf2/n6vdL+qlz//qROaVpDMx7jWTqZH/eL8CTLXl1cDLI+MOtdpi\n9UPz1MeZQ5I0Id3fzVVVlaTOxsGc7TmSbGN4KoypqSkGg8FY80+thDuuPj7Wtr3GPeZz7dixY+ft\nsU2C/TiVPZlrqfdj3DB59cSppXYa67VWPwxcNTJuTasd5v+fsjpRH7T6mnnGjzPHKapqB7ADYHp6\nusb9krV7HtjNXQcm852YL90yM5F5T+dC/dK6c8V+nMqezLXU+zHuaa49wIk7srYCu0fqt7Y7rjYC\nb7ZTVfuATUlWtQvvm4B9bd1bSTa2u7huPWlfZzKHJGlCTvvP7SRfZfip4ookhxjelfV54KEktwHf\nAz7Uhu8FbgRmgR8CHwGoqqNJPgs83sZ95sTFeOBjDO8YW8nwwvvDrX5Gc0iSJuft3M314QVWXTfP\n2AJuX2A/O4Gd89SfAN47T/31M51DkjQZ/gW8JKmbYSJJ6maYSJK6GSaSpG6GiSSpm2EiSepmmEiS\nuhkmkqRuhokkqZthIknqZphIkroZJpKkboaJJKmbYSJJ6maYSJK6GSaSpG6GiSSpm2EiSepmmEiS\nuhkmkqRuhokkqZthIknqZphIkroZJpKkboaJJKmbYSJJ6maYSJK6GSaSpG6GiSSpm2EiSepmmEiS\nuhkmkqRuXWGS5F8neTbJM0m+muSnk6xL8liS2SR/kOTiNvZd7flsW792ZD+fbvUXklw/Ut/carNJ\nto/U551DkjQZY4dJktXAvwKmq+q9wEXAzcAXgC9W1S8AbwC3tU1uA95o9S+2cSTZ0Lb7RWAz8LtJ\nLkpyEfBl4AZgA/DhNpZF5pAkTUDvaa4VwMokK4CfAY4AvwJ8ra3fBdzUlre057T11yVJqz9YVT+q\nqu8Cs8C17TFbVS9W1Y+BB4EtbZuF5pAkTcDYYVJVh4H/APw5wxB5E3gS+H5VHW/DDgGr2/Jq4OW2\n7fE2/vLR+knbLFS/fJE5JEkTsGLcDZOsYvipYh3wfeC/MDxNdd5Isg3YBjA1NcVgMBhrP1Mr4Y6r\nj59+4Dkw7jGfa8eOHTtvj20S7Mep7MlcS7
0fY4cJ8I+B71bVXwAk+TrwAeDSJCvaJ4c1wOE2/jBw\nFXConRZ7D/D6SP2E0W3mq7++yBxzVNUOYAfA9PR0zczMjPVC73lgN3cd6GnV+F66ZWYi857OYDBg\n3H4uRfbjVPZkrqXej55rJn8ObEzyM+06xnXAc8C3gF9rY7YCu9vynvactv6bVVWtfnO722sdsB74\nNvA4sL7duXUxw4v0e9o2C80hSZqAnmsmjzG8CP4d4EDb1w7gU8Ank8wyvL5xX9vkPuDyVv8ksL3t\n51ngIYZB9A3g9qr6SfvU8XFgH/A88FAbyyJzSJImoOvcTVXdCdx5UvlFhndinTz2r4FfX2A/nwM+\nN099L7B3nvq8c0iSJsO/gJckdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1\nM0wkSd0ME0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1\nM0wkSd0ME0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1K0rTJJc\nmuRrSf53kueT/KMklyXZn+Rg+7mqjU2Su5PMJnk6yTUj+9naxh9MsnWk/r4kB9o2dydJq887hyRp\nMno/mXwJ+EZV/X3gHwDPA9uBR6pqPfBIew5wA7C+PbYB98IwGIA7gfcD1wJ3joTDvcBHR7bb3OoL\nzSFJmoCxwyTJe4APAvcBVNWPq+r7wBZgVxu2C7ipLW8B7q+hR4FLk1wJXA/sr6qjVfUGsB/Y3Na9\nu6oeraoC7j9pX/PNIUmagJ5PJuuAvwD+c5I/TvJ7SS4BpqrqSBvzCjDVllcDL49sf6jVFqsfmqfO\nInNIkiZgRee21wC/UVWPJfkSJ51uqqpKUj0HeDqLzZFkG8NTakxNTTEYDMaaY2ol3HH18bGPsce4\nx3yuHTt27Lw9tkmwH6eyJ3Mt9X70hMkh4FBVPdaef41hmLya5MqqOtJOVb3W1h8GrhrZfk2rHQZm\nTqoPWn3NPONZZI45qmoHsANgenq6ZmZm5ht2Wvc8sJu7DvS0anwv3TIzkXlPZzAYMG4/lyL7cSp7\nMtdS78fYp7mq6hXg5SR/r5WuA54D9gAn7sjaCuxuy3uAW9tdXRuBN9upqn3ApiSr2oX3TcC+tu6t\nJBvbXVy3nrSv+eaQJE1A7z+3fwN4IMnFwIvARxgG1ENJbgO+B3yojd0L3AjMAj9sY6mqo0k+Czze\nxn2mqo625Y8BXwFWAg+3B8DnF5hDkjQBXWFSVU8B0/Osum6esQXcvsB+dgI756k/Abx3nvrr880h\nSZoM/wJektTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wkSd0M\nE0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wkSd0M\nE0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVK37jBJclGSP07y39rzdUke\nSzKb5A+SXNzq72rPZ9v6tSP7+HSrv5Dk+pH65labTbJ9pD7vHJKkyTgbn0w+ATw/8vwLwBer6heA\nN4DbWv024I1W/2IbR5INwM3ALwKbgd9tAXUR8GXgBmAD8OE2drE5JEkT0BUmSdYAvwr8Xnse4FeA\nr7Uhu4Cb2vKW9py2/ro2fgvwYFX9qKq+C8wC17bHbFW9WFU/Bh4EtpxmDknSBPR+MvmPwL8F/qY9\nvxz4flUdb88PAavb8mrgZYC2/s02/v/VT9pmofpic0iSJmDFuBsm+SfAa1X1ZJKZs3dIZ0+SbcA2\ngKmpKQaDwVj7mVoJd1x9/PQDz4Fxj/lcO3bs2Hl7bJNgP05lT+Za6v0YO0yADwD/NMmNwE8D7wa+\nBFyaZE
X75LAGONzGHwauAg4lWQG8B3h9pH7C6Dbz1V9fZI45qmoHsANgenq6ZmZmxnqh9zywm7sO\n9LRqfC/dMjOReU9nMBgwbj+XIvtxKnsy11Lvx9inuarq01W1pqrWMryA/s2qugX4FvBrbdhWYHdb\n3tOe09Z/s6qq1W9ud3utA9YD3wYeB9a3O7cubnPsadssNIckaQLOxd+ZfAr4ZJJZhtc37mv1+4DL\nW/2TwHaAqnoWeAh4DvgGcHtV/aR96vg4sI/h3WIPtbGLzSFJmoCzcu6mqgbAoC2/yPBOrJPH/DXw\n6wts/zngc/PU9wJ756nPO4ckaTL8C3hJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wk\nSd0ME0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wk\nSd0ME0lSN8NEktTNMJEkdTNMJEndDBNJUjfDRJLUzTCRJHUzTCRJ3QwTSVI3w0SS1M0wkSR1M0wk\nSd3GDpMkVyX5VpLnkjyb5BOtflmS/UkOtp+rWj1J7k4ym+TpJNeM7GtrG38wydaR+vuSHGjb3J0k\ni80hSZqMnk8mx4E7qmoDsBG4PckGYDvwSFWtBx5pzwFuANa3xzbgXhgGA3An8H7gWuDOkXC4F/jo\nyHabW32hOSRJEzB2mFTVkar6Tlv+K+B5YDWwBdjVhu0CbmrLW4D7a+hR4NIkVwLXA/ur6mhVvQHs\nBza3de+uqkerqoD7T9rXfHNIkibgrFwzSbIW+CXgMWCqqo60Va8AU215NfDyyGaHWm2x+qF56iwy\nhyRpAlb07iDJzwJ/CPxmVb3VLmsAUFWVpHrnWMxicyTZxvCUGlNTUwwGg7HmmFoJd1x9fOxj7DHu\nMZ9rx44dO2+PbRLsx6nsyVxLvR9dYZLkpxgGyQNV9fVWfjXJlVV1pJ2qeq3VDwNXjWy+ptUOAzMn\n1Qetvmae8YvNMUdV7QB2AExPT9fMzMx8w07rngd2c9eB7twdy0u3zExk3tMZDAaM28+lyH6cyp7M\ntdT70XM3V4D7gOer6ndGVu0BTtyRtRXYPVK/td3VtRF4s52q2gdsSrKqXXjfBOxr695KsrHNdetJ\n+5pvDknSBPT8c/sDwD8HDiR5qtX+HfB54KEktwHfAz7U1u0FbgRmgR8CHwGoqqNJPgs83sZ9pqqO\ntuWPAV8BVgIPtweLzCFJmoCxw6Sq/heQBVZfN8/4Am5fYF87gZ3z1J8A3jtP/fX55pAkTYZ/AS9J\n6maYSJK6GSaSpG6GiSSpm2EiSepmmEiSuhkmkqRuhokkqZthIknqZphIkroZJpKkboaJJKmbYSJJ\n6maYSJK6GSaSpG6GiSSpm2EiSepmmEiSuhkmkqRuhokkqZthIknqZphIkroZJpKkboaJJKmbYSJJ\n6maYSJK6GSaSpG6GiSSpm2EiSepmmEiSuhkmkqRuhokkqdsFGyZJNid5Iclsku2TPh5JWs4uyDBJ\nchHwZeAGYAPw4SQbJntUkrR8XZBhAlwLzFbVi1X1Y+BBYMuEj0mSlq0LNUxWAy+PPD/UapKkCVgx\n6QM4l5JsA7a1p8eSvDDmrq4A/vLsHNWZyRcmMevbMrGenKfsx6nsyVwXaj/+ztsZdKGGyWHgqpHn\na1ptjqraAezonSzJE1U13bufpcSezGU/TmVP5lrq/bhQT3M9DqxPsi7JxcDNwJ4JH5MkLVsX5CeT\nqjqe5OPAPuAiYGdVPTvhw5KkZeuCDBOAqtoL7H2Hpus+VbYE2ZO57Mep7MlcS7ofqapJH4Mk6QJ3\noV4zkSSdRwyT01iuX9uS5KUkB5I8leSJVrssyf4kB9vPVa2eJHe3Hj2d5JrJHv3ZkWRnkteSPDNS\nO+MeJNnaxh9MsnUSr+VsWKAfv53kcHufPJXkxpF1n279eCHJ9SP1JfE7
leSqJN9K8lySZ5N8otWX\n53ukqnws8GB4cf/PgJ8HLgb+BNgw6eN6h177S8AVJ9X+PbC9LW8HvtCWbwQeBgJsBB6b9PGfpR58\nELgGeGbcHgCXAS+2n6va8qpJv7az2I/fBv7NPGM3tN+XdwHr2u/RRUvpdwq4ErimLf8c8KftdS/L\n94ifTBbn17bMtQXY1ZZ3ATeN1O+voUeBS5NcOYkDPJuq6o+AoyeVz7QH1wP7q+poVb0B7Ac2n/uj\nP/sW6MdCtgAPVtWPquq7wCzD36cl8ztVVUeq6jtt+a+A5xl+E8eyfI8YJotbzl/bUsD/SPJk+yYB\ngKmqOtKWXwGm2vJy6tOZ9mA59Obj7bTNzhOndFhm/UiyFvgl4DGW6XvEMNFCfrmqrmH4zcy3J/ng\n6Moafj5f1rcC2gMA7gX+LvAPgSPAXZM9nHdekp8F/hD4zap6a3TdcnqPGCaLe1tf27IUVdXh9vM1\n4L8yPD3x6onTV+3na234curTmfZgSfemql6tqp9U1d8A/4nh+wSWST+S/BTDIHmgqr7eysvyPWKY\nLG5Zfm1LkkuS/NyJZWAT8AzD137iTpOtwO62vAe4td2tshF4c+Rj/lJzpj3YB2xKsqqdAtrUakvC\nSdfG/hnD9wkM+3FzknclWQesB77NEvqdShLgPuD5qvqdkVXL8z0y6TsAzvcHwzsw/pThHSi/Nenj\neYde888zvMvmT4BnT7xu4HLgEeAg8D+By1o9DP9nZX8GHACmJ/0azlIfvsrw1M3/YXge+7ZxegD8\nS4YXoGeBj0z6dZ3lfvx+e71PM/yP5ZUj43+r9eMF4IaR+pL4nQJ+meEprKeBp9rjxuX6HvEv4CVJ\n3TzNJUnqZphIkroZJpKkboaJJKmbYSJJ6maYSJK6GSaSpG6GiSSp2/8Fqu+LlacobiMAAAAASUVO\nRK5CYII=\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "count 1.010773e+06\n",
+ "mean 1.042197e+01\n",
+ "std 1.050239e+01\n",
+ "min 0.000000e+00\n",
+ "25% 5.000000e+00\n",
+ "50% 9.000000e+00\n",
+ "75% 1.400000e+01\n",
+ "max 2.222000e+03\n",
+ "dtype: float64"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 39
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "LBgg-Fe6JcF0",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "sequence_length=500\n",
+ "features=np.zeros((len(encoded_reviews), sequence_length), dtype=int)\n",
+ "for i, review in enumerate(encoded_reviews):\n",
+ " review_len=len(review)\n",
+ " if (review_len<=sequence_length):\n",
+ " zeros=list(np.zeros(sequence_length-review_len))\n",
+ " new=zeros+review\n",
+ " else:\n",
+ " new=review[:sequence_length]\n",
+ "features[i,:]=np.array(new)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "q-njY7CiKk4T",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "labels=train_data['label']"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "KSkcQMknKF-k",
+ "colab_type": "code",
+ "outputId": "f00d2da4-ec29-4820-a807-6b70e703248b",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ }
+ },
+ "source": [
+ "train_x=features[:int(0.8*len(features))]\n",
+ "train_y=labels[:int(0.8*len(features))]\n",
+ "valid_x=features[int(0.8*len(features)):]\n",
+ "valid_y=labels[int(0.8*len(features)):]\n",
+ "train_y=train_y.to_numpy()\n",
+ "valid_y=valid_y.to_numpy()\n",
+ "print(type(train_x), len(valid_y))\n",
+ "print(len(train_data))"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ " 202155\n",
+ "1010773\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "zTy3ma7HNomd",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "traindata = TensorDataset(torch.from_numpy(train_x), torch.from_numpy(train_y))\n",
+ "validdata = TensorDataset(torch.from_numpy(valid_x), torch.from_numpy(valid_y))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "FFDDBYvtXvzX",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "batch_size=50\n",
+ "train_loader=DataLoader(traindata, batch_size=batch_size, shuffle=True)\n",
+ "valid_loader=DataLoader(validdata, batch_size=batch_size, shuffle=True)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "KnhnFbL0XwFr",
+ "colab_type": "code",
+ "outputId": "2d496352-0ae6-4978-d1da-073ecda2a5f2",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 272
+ }
+ },
+ "source": [
+ "dataiter = iter(train_loader)\n",
+ "sample_x, sample_y = dataiter.next()\n",
+ "print('Sample input size: ', sample_x.size()) # batch_size, seq_length\n",
+ "print('Sample input: \\n', sample_x)\n",
+ "print()\n",
+ "print('Sample label size: ', sample_y.size()) # batch_size\n",
+ "print('Sample label: \\n', sample_y)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Sample input size: torch.Size([50, 500])\n",
+ "Sample input: \n",
+ " tensor([[0, 0, 0, ..., 0, 0, 0],\n",
+ " [0, 0, 0, ..., 0, 0, 0],\n",
+ " [0, 0, 0, ..., 0, 0, 0],\n",
+ " ...,\n",
+ " [0, 0, 0, ..., 0, 0, 0],\n",
+ " [0, 0, 0, ..., 0, 0, 0],\n",
+ " [0, 0, 0, ..., 0, 0, 0]])\n",
+ "\n",
+ "Sample label size: torch.Size([50])\n",
+ "Sample label: \n",
+ " tensor([1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0,\n",
+ " 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0,\n",
+ " 0, 1])\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "lov_X4LhZZSP",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import torch.nn as nn\n",
+ "\n",
+ "class SentimentLSTM(nn.Module):\n",
+ " \n",
+ "\n",
+ " def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5):\n",
+ " \"\"\"\n",
+ " Initialize the model by setting up the layers.\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ "\n",
+ " self.output_size = output_size\n",
+ " self.n_layers = n_layers\n",
+ " self.hidden_dim = hidden_dim\n",
+ " \n",
+ " \n",
+ " self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
+ " self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, \n",
+ " dropout=drop_prob, batch_first=True)\n",
+ " \n",
+ " \n",
+ " self.dropout = nn.Dropout(0.3)\n",
+ " \n",
+ " \n",
+ " self.fc = nn.Linear(hidden_dim, output_size)\n",
+ " self.sig = nn.Sigmoid()\n",
+ " \n",
+ "\n",
+ " def forward(self, x, hidden):\n",
+ " \n",
+ " batch_size = x.size(0)\n",
+ "\n",
+ " \n",
+ " embeds = self.embedding(x)\n",
+ " lstm_out, hidden = self.lstm(embeds, hidden)\n",
+ " \n",
+ " \n",
+ " lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)\n",
+ " out = self.dropout(lstm_out)\n",
+ " out = self.fc(out)\n",
+ " \n",
+ " sig_out = self.sig(out)\n",
+ " \n",
+ " sig_out = sig_out.view(batch_size, -1)\n",
+ " sig_out = sig_out[:, -1] \n",
+ " return sig_out, hidden\n",
+ " \n",
+ " \n",
+ " def init_hidden(self, batch_size):\n",
+ " \n",
+ " weight = next(self.parameters()).data\n",
+ " \n",
+ " if (train_on_gpu):\n",
+ " hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda(),\n",
+ " weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda())\n",
+ " else:\n",
+ " hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_(),\n",
+ " weight.new(self.n_layers, batch_size, self.hidden_dim).zero_())\n",
+ " \n",
+ " return hidden\n",
+ " "
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "3IoUHKG6bkGb",
+ "colab_type": "code",
+ "outputId": "207f1a1b-3769-4bac-8a5f-2a3418ad4048",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 136
+ }
+ },
+ "source": [
+ "vocab_size = len(vocab_to_int)+1 # +1 for the 0 padding\n",
+ "output_size = 1\n",
+ "embedding_dim = 400\n",
+ "hidden_dim = 256\n",
+ "n_layers = 2\n",
+ "net = SentimentLSTM(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)\n",
+ "print(net)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "SentimentLSTM(\n",
+ " (embedding): Embedding(206968, 400)\n",
+ " (lstm): LSTM(400, 256, num_layers=2, batch_first=True, dropout=0.5)\n",
+ " (dropout): Dropout(p=0.3)\n",
+ " (fc): Linear(in_features=256, out_features=1, bias=True)\n",
+ " (sig): Sigmoid()\n",
+ ")\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "f26FXAbUeS72",
+ "colab_type": "code",
+ "outputId": "6191d34f-f788-415c-9e8d-6e8ed84dffbb",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 394
+ }
+ },
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "train_on_gpu = torch.cuda.is_available()\n",
+ "lr=0.001\n",
+ "\n",
+ "criterion = nn.BCELoss()\n",
+ "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
+ "\n",
+ "epochs = 4 \n",
+ "\n",
+ "counter = 0\n",
+ "print_every = 100\n",
+ "clip=5\n",
+ "if(train_on_gpu):\n",
+ " net.cuda()\n",
+ "\n",
+ "net.train()\n",
+ "for e in range(epochs):\n",
+ " \n",
+ " h = net.init_hidden(batch_size)\n",
+ " for inputs, labels in train_loader:\n",
+ " counter += 1\n",
+ "\n",
+ " if(train_on_gpu):\n",
+ " inputs, labels = inputs.cuda(), labels.cuda()\n",
+ "\n",
+ " h = tuple([each.data for each in h])\n",
+ "\n",
+ " # zero accumulated gradients\n",
+ " net.zero_grad()\n",
+ "\n",
+ " # get the output from the model\n",
+ " inputs = inputs.type(torch.cuda.LongTensor)\n",
+ " output,h = net(inputs, h)\n",
+ " print(output.shape)\n",
+ " # calculate the loss and perform backprop\n",
+ " loss = criterion(output.squeeze(), labels.float())\n",
+ " loss.backward(retain_graph=True)\n",
+ " \n",
+ " nn.utils.clip_grad_norm_(net.parameters(), clip)\n",
+ " optimizer.step()\n",
+ "\n",
+ " \n",
+ " if counter % print_every == 0:\n",
+ " \n",
+ " val_h = net.init_hidden(batch_size)\n",
+ " val_losses = []\n",
+ " net.eval()\n",
+ " for inputs, labels in valid_loader:\n",
+ "\n",
+ " \n",
+ " val_h = tuple([each.data for each in val_h])\n",
+ "\n",
+ " if(train_on_gpu):\n",
+ " inputs, labels = inputs.cuda(), labels.cuda()\n",
+ "\n",
+ " inputs = inputs.type(torch.cuda.LongTensor)\n",
+ " output, val_h = net(inputs, val_h)\n",
+ " val_loss = criterion(output.squeeze(), labels.float())\n",
+ "\n",
+ " val_losses.append(val_loss.item())\n",
+ "\n",
+ " net.train()\n",
+ " print(\"Epoch: {}/{}...\".format(e+1, epochs),\n",
+ " \"Step: {}...\".format(counter),\n",
+ " \"Loss: {:.6f}...\".format(loss.item()),\n",
+ " \"Val Loss: {:.6f}\".format(np.mean(val_losses)))"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "error",
+ "ename": "RuntimeError",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_h\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_h\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0mval_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 491\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 492\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 494\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, hidden)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0membeds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0mlstm_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlstm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0membeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 491\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 492\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 494\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 495\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 557\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_packed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 559\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 560\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 561\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward_tensor\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 537\u001b[0m \u001b[0munsorted_indices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 539\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_batch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msorted_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 540\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 541\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermute_hidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0munsorted_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward_impl\u001b[0;34m(self, input, hx, batch_sizes, max_batch_size, sorted_indices)\u001b[0m\n\u001b[1;32m 517\u001b[0m \u001b[0mhx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermute_hidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msorted_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 518\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 519\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_forward_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 520\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_sizes\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 521\u001b[0m result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mcheck_forward_args\u001b[0;34m(self, input, hidden, batch_sizes)\u001b[0m\n\u001b[1;32m 492\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 493\u001b[0m self.check_hidden_size(hidden[0], expected_hidden_size,\n\u001b[0;32m--> 494\u001b[0;31m 'Expected hidden[0] size {}, got {}')\n\u001b[0m\u001b[1;32m 495\u001b[0m self.check_hidden_size(hidden[1], expected_hidden_size,\n\u001b[1;32m 496\u001b[0m 'Expected hidden[1] size {}, got {}')\n",
+ "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mcheck_hidden_size\u001b[0;34m(self, hx, expected_hidden_size, msg)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;31m# type: (Tensor, Tuple[int, int, int], str) -> None\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mexpected_hidden_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 172\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpected_hidden_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 173\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcheck_forward_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mRuntimeError\u001b[0m: Expected hidden[0] size (2, 5, 256), got (2, 50, 256)"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "eB_CGTGET05V",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "BSheCkROT095",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "M3CXZzIET1Iz",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "0vfJvTiaT1O2",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "UeGe55t6T1Sk",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "tWbk1HpzT1W0",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "5cyxIOvlT1bi",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "AhywpMP7T1gU",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "_MXann70T1mB",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file