Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Platform with multiple servers control #19

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ venv/
/.pytest_cache/
/outline-install.log
__pycache__
.env
77 changes: 77 additions & 0 deletions outline_vpn/example_platform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from platform import Platform
from dotenv import load_dotenv
import os

load_dotenv()

servers = []
apis = []
certs = []

i = 1
while True:
server = os.getenv(f'SERVER_{i}')
if server is None:
break
cert = os.getenv(f'CERT_{i}')
if cert is None:
break
api = os.getenv(f'API_{i}')
if api is None:
break
servers += [server]
apis += [api]
certs += [cert]
i += 1

print(servers)
print(apis)
print(certs)

platform = Platform(servers, apis, certs)
print(platform)

print('-----Platfrom initialized-----\n')

current_users = platform.user_dict
print(f'Found {len(current_users)} users\n')
for user_name in current_users.keys():
print(current_users[user_name])

print('\n-----Adding users-----\n\n')
users_to_register = ['User1', 'User2', 'User3', 'User4']

for i, user in enumerate(users_to_register):
platform.create_new_key(user, limit_value=(i + 5) * 1000 * 1000)


current_users = platform.user_dict
print(f'Found {len(current_users)} users\n')
for user_name in current_users.keys():
print(current_users[user_name])

print('\n-----Setting new limits-----\n\n')
for i, user in enumerate(users_to_register):
print('Old: ' + platform.get_key(user).__str__())
print('data from get_balance: ', end='')
print(platform.get_balance(user))

platform.set_limit(user, 10 * (i + 5) * 1000 * 1000)
print('New: '+ platform.get_key(user).__str__())
print('data from get_balance: ', end='')
print(platform.get_balance(user))

platform.bump_limit(user, 123)
print('After bump: ' + platform.get_key(user).__str__())
print('data from get_balance: ', end='')
print(platform.get_balance(user))

print('\n-----Removing users-----\n\n')

for i, user in enumerate(users_to_register):
platform.remove_user(user)

current_users = platform.user_dict
print(f'Found {len(current_users)} users\n')
for user_name in current_users.keys():
print(current_users[user_name])
183 changes: 183 additions & 0 deletions outline_vpn/platform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from outline_vpn import OutlineVPN
import random
import datetime


class Platform:
def __init__(self, servers, apis, certs):
self.servers = servers # server names
self.apis = apis # API keys received from Outline server
self.certs = certs # certificates received from Outline server

self.vpn_servers = [] # list of OutlineVPN client objects
self.user_dict = {} # dict with users' keys

# initialize the servers
for i in range(len(servers)):
vpn_server = OutlineVPN(api_url=apis[i], cert_sha256=certs[i])
self.vpn_servers += [vpn_server]

# load the actual (last created for each user) keys from the servers
self.load_keys()

def load_keys(self, server_id=None):
user_dict_list = {} # temporary dict of lists to get all existing keys and then keep only last created per user
for i, vpn_server in enumerate(self.vpn_servers):
if server_id is not None:
if i != server_id:
continue
keys = vpn_server.get_keys()
if keys is not None:
for key in keys:
# ignoring any not standartized names like "{username string},{datetime like %Y-%m-%d %H:%M:%S}"
if Platform.check_name(key.name):
str_user_name = Platform.get_user_name(key.name)
if str_user_name in user_dict_list.keys():
user_dict_list[str_user_name] += [Key(key, i)]
else:
user_dict_list[str_user_name] = [Key(key, i)]

# to ensure the following correct work we need to remove the old keys
for str_user_id in user_dict_list.keys():
if len(user_dict_list[str_user_id]) > 1:
keys = user_dict_list[str_user_id]
dates = [self.get_date(k.name) for k in keys]

sorted_keys = [x for _, x in sorted(zip(dates, keys), key=lambda pair: pair[0])]
keep_key = sorted_keys[-1]
to_remove_keys = sorted_keys[:-1]
for key in to_remove_keys:
self.vpn_servers[key.server_id].delete_key(key.key_id)
self.user_dict[str_user_id] = keep_key
else:
self.user_dict[str_user_id] = user_dict_list[str_user_id][0]

@staticmethod
def is_valid_date(date_string):
try:
datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
return True
except ValueError:
return False

@staticmethod
def check_name(name):
len_ = len(name.split(','))
if len_ == 2:
dt = name.split(',')[1]
if Platform.is_valid_date(dt):
return True
else:
return False
else:
return False

@staticmethod
def get_date(name):
return name.split(',')[1]

@staticmethod
def get_user_name(name):
return name.split(',')[0]

def get_key(self, str_user_id):
if str_user_id in self.user_dict.keys():
return self.user_dict[str_user_id]
else:
return None

def remove_user(self, str_user_id):
if str_user_id in self.user_dict.keys():
key = self.user_dict[str_user_id]
del self.user_dict[str_user_id] # remove from cached (soft delete)
self.vpn_servers[key.server_id].delete_key(key.key_id) # hard delete can fail

def get_balance(self, str_user_id):
if str_user_id in self.user_dict.keys():
key = self.user_dict[str_user_id]
self.load_keys(server_id=key.server_id)

key = self.user_dict[str_user_id]
return key.used_bytes, key.data_limit
else:
return None, None

def set_limit(self, str_user_id, limit_value):
if str_user_id in self.user_dict.keys():
key = self.user_dict[str_user_id]
result = self.vpn_servers[key.server_id].add_data_limit(key.key_id, limit_value)
if result:
key.data_limit = limit_value
self.user_dict[str_user_id] = key
return True
else:
return False
else:
return False

def bump_limit(self, str_user_id, addition_value):
used, data_limit = self.get_balance(str_user_id)
if used is None:
return False # User not found
else:
key = self.user_dict[str_user_id]
data_limit_new = data_limit + addition_value
result = self.vpn_servers[key.server_id].add_data_limit(key.key_id, data_limit_new)
if result:
key.data_limit = data_limit_new
self.user_dict[str_user_id] = key
return True
else:
return False

def create_new_key(self, str_user_id, limit_value=None, forced_server_id=None):
# forced_server_id is an integer to choose which server id to use
last_server_id = -1
if str_user_id in self.user_dict.keys():
key = self.user_dict[str_user_id]

last_server_id = key.server_id
self.vpn_servers[key.server_id].delete_key(key.key_id)

if forced_server_id is None:
if last_server_id == -1:
new_server_id = random.randint(0, len(self.vpn_servers) - 1)
else:
new_server_id = (last_server_id + 1) % len(self.vpn_servers)
else:
new_server_id = forced_server_id

now = datetime.datetime.now()
now_str = now.strftime("%Y-%m-%d %H:%M:%S")

key = self.vpn_servers[new_server_id].create_key(str_user_id+','+now_str)
if limit_value is not None:
self.vpn_servers[new_server_id].add_data_limit(key.key_id, limit_value)

new_key = Key(key, new_server_id)
new_key.data_limit = limit_value
self.user_dict[str_user_id] = new_key

return new_key

def __str__(self):
print_line = f"Number of servers is {len(self.vpn_servers)}:\n"
for i in range(len(self.vpn_servers)):
print_line += f"-{self.servers[i]}\n"
return print_line


class Key:
def __init__(self, key, server_id):
self.key_id = key.key_id
self.name = key.name
self.password = key.password
self.port = key.port
self.method = key.method
self.access_url = key.access_url
self.used_bytes = 0 if key.used_bytes is None else key.used_bytes
self.data_limit = key.data_limit # might be None
self.server_id = server_id

def __str__(self):
return f"server #{self.server_id+1}, key_id {self.key_id}, name={self.name}, used_mb={self.used_bytes/1000000}, data_limit_mb={self.data_limit/1000000}"