Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
1 change: 1 addition & 0 deletions LSTM-Variational-AutoEncoder
Submodule LSTM-Variational-AutoEncoder added at 50476d
50 changes: 27 additions & 23 deletions main.py → Main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from data.ptb import PTB
from VGDLDataGeneralized.ptb import PTB

import torch
from loss import VAE_Loss
Expand All @@ -17,31 +17,31 @@
torch.manual_seed(global_setting["seed"])


parser = argparse.ArgumentParser(description=" A parser for baseline uniform noisy experiment")
parser.add_argument("--batch_size", type=str, default="32")
parser.add_argument("--bptt", type=str,default="60")
parser.add_argument("--embed_size", type=str, default="300")
parser.add_argument("--hidden_size", type=str, default="256")
parser.add_argument("--latent_size", type=str, default="16")
"""parser = argparse.ArgumentParser(description=" A parser for baseline uniform noisy experiment")
parser.add_argument("--batch_size", type=str, default="20")
parser.add_argument("--bptt", type=str,default="600")
parser.add_argument("--embed_size", type=str, default="50")
parser.add_argument("--hidden_size", type=str, default="400")
parser.add_argument("--latent_size", type=str, default="60")
parser.add_argument("--lr", type=str, default="0.001")


# Extract commandline arguments
args = parser.parse_args()

batch_size = int(args.batch_size) if args.batch_size!=None else training_setting["batch_size"]
bptt = int(args.bptt) if args.bptt!=None else training_setting["bptt"]
embed_size = int(args.embed_size) if args.embed_size!=None else training_setting["embed_size"]
hidden_size = int(args.hidden_size) if args.hidden_size!=None else training_setting["hidden_size"]
latent_size = int(args.latent_size) if args.latent_size!=None else training_setting["latent_size"]
lr = float(args.lr) if args.lr!=None else training_setting["lr"]
args = parser.parse_args()"""

batch_size = training_setting["batch_size"]
bptt = training_setting["bptt"]
embed_size = model_setting["embed_size"]
hidden_size = model_setting["hidden_size"]
latent_size = model_setting["latent_size"]
lr = training_setting["lr"]

data_dir = "./VGDLDataGeneralized"

# Load the data
train_data = PTB(data_dir="./data", split="train", create_data= False, max_sequence_length= bptt)
test_data = PTB(data_dir="./data", split="test", create_data= False, max_sequence_length=bptt)
valid_data = PTB(data_dir="./data", split="valid", create_data= False, max_sequence_length= bptt)
train_data = PTB(data_dir=data_dir, split="train", create_data= False, max_sequence_length= bptt)
test_data = PTB(data_dir=data_dir, split="test", create_data= False, max_sequence_length=bptt)
valid_data = PTB(data_dir=data_dir, split="valid", create_data= False, max_sequence_length= bptt)

# Batchify the data
train_loader = torch.utils.data.DataLoader( dataset= train_data, batch_size=batch_size, shuffle= True)
Expand All @@ -51,18 +51,15 @@


vocab_size = train_data.vocab_size
model = LSTM_VAE(vocab_size = vocab_size, embed_size = embed_size, hidden_size = hidden_size, latent_size = latent_size).to(device)
model = LSTM_VAE(vocab_size = vocab_size, embed_size = embed_size, hidden_size = hidden_size, latent_size = latent_size, data_dir=data_dir).to(device)

Loss = VAE_Loss()
optimizer = torch.optim.Adam(model.parameters(), lr= training_setting["lr"])

trainer = Trainer(train_loader, test_loader, model, Loss, optimizer)




if __name__ == "__main__":

def main():
# Epochs
train_losses = []
test_losses = []
Expand All @@ -72,11 +69,18 @@
train_losses = trainer.train(train_losses, epoch, training_setting["batch_size"], training_setting["clip"])
print("Testing.......")
test_losses = trainer.test(test_losses, epoch, training_setting["batch_size"])
if epoch % 50 == 0:
torch.save(model.state_dict(), "models/VGDL_VAE_GENERALIZED2_" + str(epoch) + ".pt")


plot_elbo(train_losses, "train")
plot_elbo(test_losses, "test")

torch.save(model.state_dict(), "models/VGDL_VAE2.pt")

if __name__ == "__main__":

main()



52 changes: 52 additions & 0 deletions VGDLData/CreateSplits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@


import csv
import random
import math


if __name__ == "__main__":
all_files = []
percent_train = 0.8
percent_valid = 0.1
percent_test = 0.1

with open("VGDLData/examples/all_games_sp.csv") as csv_file:
csv_reader = csv.reader(csv_file, delimiter=',')
for row in csv_reader:
all_files.append(row[1])
random.shuffle(all_files)


train_index = round(len(all_files) * percent_train)
test_index = train_index + round(len(all_files) * percent_test)


training_files = all_files[:train_index]
test_files = all_files[train_index:test_index]
valid_files = all_files[test_index:]

with open("VGDLData/ptb.all.csv", "w") as file:
cvs_writer = csv.writer(file, delimiter=",")
for i, f in enumerate(all_files):
cvs_writer.writerow([i, f])

with open("VGDLData/ptb.train.csv", "w") as file:
cvs_writer = csv.writer(file, delimiter=",")
for i, f in enumerate(training_files):
cvs_writer.writerow([i, f])

with open("VGDLData/ptb.test.csv", "w") as file:
cvs_writer = csv.writer(file, delimiter=",")
for i, f in enumerate(test_files):
cvs_writer.writerow([i, f])

with open("VGDLData/ptb.valid.csv", "w") as file:
cvs_writer = csv.writer(file, delimiter=",")
for i, f in enumerate(valid_files):
cvs_writer.writerow([i, f])





Binary file added VGDLData/__pycache__/ptb.cpython-310.pyc
Binary file not shown.
Binary file added VGDLData/__pycache__/ptb.cpython-37.pyc
Binary file not shown.
Binary file added VGDLData/__pycache__/utils.cpython-310.pyc
Binary file not shown.
Binary file added VGDLData/__pycache__/utils.cpython-37.pyc
Binary file not shown.
8 changes: 8 additions & 0 deletions VGDLData/createData.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ptb import PTB
if __name__ == "__main__":
data_dir = "VGDLData"

all = PTB(data_dir, "all",create_data=True)
training = PTB(data_dir, "train",create_data=True)
test = PTB(data_dir, "test",create_data=True)
valid = PTB(data_dir, "valid",create_data=True)
64 changes: 64 additions & 0 deletions VGDLData/examples/2player/accelerator.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
BasicGame no_players=2

SpriteSet
land > Immovable hidden=True img=oryx/grass autotiling=True
structure > Immovable
water > color=BLUE img=newset/water4
goal > Door color=GREEN img=newset/exit2
winner >
winnerA > singleton=True
winnerB > singleton=True
log > Missile orientation=RIGHT speed=0.05 color=BROWN img=newset/logm
logA >
logB >
spawnPull > SpawnPoint stype=pull prob=0.1 hidden=True invisible=True
pull > Missile hidden=True invisible=True orientation=DOWN speed=0.05
avatar > ShootAvatar speed=0.2
avatarA > stype=pushA img=newset/girl2 frameRate=8
avatarB > stype=pushB img=newset/man3 frameRate=8
push > Flicker
pushA > singleton=True
pushB > singleton=True
wall > Immovable color=BLACK img=oryx/tree2

InteractionSet
avatar wall > stepBack pixelPerfect=True

avatarA goal > transformTo stype=winnerA scoreChange=1,0
avatarB goal > transformTo stype=winnerB scoreChange=0,1

logA pushA > increaseSpeedToAll stype=logA value=0.001
pushA logA > killSprite
logB pushB > increaseSpeedToAll stype=logB value=0.001
pushB logB > killSprite

avatar log > pullWithIt
avatar pull > pullWithIt
spawnPull log > pullWithIt

avatar log > shieldFrom ftype=killSprite stype=water
pull log > shieldFrom ftype=killSprite stype=water
log wall > stepBack pixelPerfect=True

avatarA water > killSprite
avatarB water > killSprite
pull water wall pull > killSprite

TerminationSet
MultiSpriteCounter stype1=winnerA stype2=winnerB limit=2 win=True,True
MultiSpriteCounter stype1=winnerA limit=1 win=True,False
MultiSpriteCounter stype1=winnerB limit=1 win=False,True
SpriteCounter stype=avatar limit=0 win=False,False
SpriteCounter stype=avatarA limit=0 win=False,True
SpriteCounter stype=avatarB limit=0 win=True,False

LevelMapping
g > goal water
. > water land
= > water logA spawnPull
+ > water logA
- > water logB spawnPull
_ > water logB
A > avatarA logA water
B > avatarB logB water
w > land wall
13 changes: 13 additions & 0 deletions VGDLData/examples/2player/accelerator_lvl0.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w..............................gw
w=====.........................gw
w++A++.........................gw
w+++++.........................gw
w..............................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w..............................gw
w-----.........................gw
w__B__.........................gw
w_____.........................gw
w..............................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
13 changes: 13 additions & 0 deletions VGDLData/examples/2player/accelerator_lvl1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w...................w..........gw
w=====..............ww.........gw
w++A++.........................gw
w+++++.........................gw
w..............................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w...................w..........gw
w-----..............ww.........gw
w__B__.........................gw
w_____.........................gw
w..............................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
15 changes: 15 additions & 0 deletions VGDLData/examples/2player/accelerator_lvl2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w...............................w
w=====..............ww..........w
w++A++..........................w
w+++++.........................gw
w+++++..............ww.........gw
w..............................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w...............................w
w-----..............ww..........w
w__B__..........................w
w_____.........................gw
w_____..............ww.........gw
w..............................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
5 changes: 5 additions & 0 deletions VGDLData/examples/2player/accelerator_lvl3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w++A++...............................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w__B__...............................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
11 changes: 11 additions & 0 deletions VGDLData/examples/2player/accelerator_lvl4.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w....................................gw
w==A==...............................gw
w+++++...............................gw
w....................................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
w....................................gw
w--B--...............................gw
w_____...............................gw
w....................................gw
wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww
46 changes: 46 additions & 0 deletions VGDLData/examples/2player/akkaarrh.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
BasicGame no_players=2 square_size=30
SpriteSet
background > Immovable hidden=True img=oryx/space1
ship > Immovable color=GREEN portal=True
nokey > img=oryx/spaceship1
withkey > img=oryx/spaceship2
explosion > Flicker limit=5 img=oryx/sparkle3

movable >
avatar > ShootAvatar stype=explosion
avatarA > img=newset/spaceman1
avatarB > img=newset/spaceman2
incoming >
incoming_slow > Chaser stype=ship color=ORANGE speed=0.05 img=oryx/alien3
incoming_fast > Chaser stype=ship color=YELLOW speed=0.15 img=oryx/alien1
enemySpawn > BomberRandomMissile stypeMissile=incoming_slow,incoming_fast invisible=True hidden=True singleton=True cooldown=8 speed=0.8 prob=0.1

winner > Immovable img=oryx/spaceship2

key > Immovable img=oryx/key1 shrinkfactor=0.7
wall > Immovable img=oryx/planet

LevelMapping
. > background
s > nokey background
e > enemySpawn background
k > key background
A > avatarA background
B > avatarB background
w > wall background

InteractionSet
enemySpawn wall > reverseDirection
movable wall > stepBack pixelPerfect=True
avatar nokey > stepBack
incoming ship > killBoth scoreChange=-1,-1
incoming explosion avatarA > killSprite scoreChange=2,0
incoming explosion avatarB > killSprite scoreChange=0,2
avatar key > transformToAll stype=nokey stypeTo=withkey
key avatarA > killSprite scoreChange=10,0
key avatarB > killSprite scoreChange=0,10
avatar withkey > transformTo stype=winner

TerminationSet
SpriteCounter stype=ship win=False,False
MultiSpriteCounter stype1=winner limit=2 win=True,True
12 changes: 12 additions & 0 deletions VGDLData/examples/2player/akkaarrh_lvl0.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
wwwwwwwwwwwwwwwwwwwwwwww
w...........e..........w
w......................w
w.k....................w
w......................w
wwww................wwww
w.........ww.ww........w
w......................w
w..w................w..w
w..w................w..w
w..w.....A..s..B....w..w
wwwwwwwwwwwwwwwwwwwwwwww
12 changes: 12 additions & 0 deletions VGDLData/examples/2player/akkaarrh_lvl1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
wwwwwwwwwwwwwwwwwwwwwwww
w...........e..........w
w......................w
w..........www.........w
w......................w
w.ww.....A.....B....ww.w
w......................w
w......................w
w......w..........wwwwww
w...s..w............k..w
w......w...............w
wwwwwwwwwwwwwwwwwwwwwwww
17 changes: 17 additions & 0 deletions VGDLData/examples/2player/akkaarrh_lvl2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
wwwwwwwwwwwwwwwww
w....e..........w
w...............w
w...............w
w...............w
w......A........w
w....www...wwww.w
w..........w....w
w...............w
w.......w....s..w
w.......w.......w
w........w......w
w..k............w
w..ww...........w
w...w.....www...w
w.........B.....w
wwwwwwwwwwwwwwwww
Loading