Rehaul request reading, add preliminary support for Gopher+
This commit is contained in:
parent
cf66f1eee2
commit
1ddaaf5bf0
1 changed files with 141 additions and 42 deletions
183
neomi.py
183
neomi.py
|
@ -5,6 +5,7 @@ import socket
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
class default_config: None
|
class default_config: None
|
||||||
|
|
||||||
|
@ -33,6 +34,15 @@ class OneArgumentException(Exception):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.text % self.argument
|
return self.text % self.argument
|
||||||
|
|
||||||
|
class UnreachableException(Exception):
|
||||||
|
def __str__(self):
|
||||||
|
return 'Declared unreachable'
|
||||||
|
|
||||||
|
# unreachable() → (Never returns)
|
||||||
|
# Used to mark a codepath that should never execute
|
||||||
|
def unreachable():
|
||||||
|
raise UnreachableException
|
||||||
|
|
||||||
# bind(port, backlog = 1) → [sockets...]
|
# bind(port, backlog = 1) → [sockets...]
|
||||||
# Binds to all available (TCP) interfaces on specified port and returns the sockets
|
# Binds to all available (TCP) interfaces on specified port and returns the sockets
|
||||||
# backlog controls how many connections allowed to wait handling before system drops new ones
|
# backlog controls how many connections allowed to wait handling before system drops new ones
|
||||||
|
@ -41,7 +51,7 @@ def bind(port, backlog = 1):
|
||||||
sockets = []
|
sockets = []
|
||||||
for res in socket.getaddrinfo(None, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
|
for res in socket.getaddrinfo(None, port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE):
|
||||||
af, socktype, proto, canonname, sa = res
|
af, socktype, proto, canonname, sa = res
|
||||||
|
|
||||||
try:
|
try:
|
||||||
s = socket.socket(af, socktype, proto)
|
s = socket.socket(af, socktype, proto)
|
||||||
except OSError:
|
except OSError:
|
||||||
|
@ -60,7 +70,7 @@ def bind(port, backlog = 1):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sockets.append(s)
|
sockets.append(s)
|
||||||
|
|
||||||
return sockets
|
return sockets
|
||||||
|
|
||||||
|
|
||||||
|
@ -75,12 +85,44 @@ def drop_privileges():
|
||||||
except:
|
except:
|
||||||
die('Unable to drop privileges')
|
die('Unable to drop privileges')
|
||||||
|
|
||||||
|
class CommandError(OneArgumentException):
|
||||||
|
text = 'Error with command: %s'
|
||||||
|
|
||||||
|
class SocketReadError(OneArgumentException):
|
||||||
|
text = 'Error reading socket: %s'
|
||||||
|
|
||||||
|
class SocketReaderCommands(enum.Enum):
|
||||||
|
stop = range(1)
|
||||||
|
|
||||||
|
# SocketReader(sock) → <SocketReader instance>
|
||||||
|
# next(<SocketReader instance>) → byte_of_data
|
||||||
|
# Wraps a socket and exposes it as per-byte iterator. Does not close the socket when it exits
|
||||||
|
def SocketReader(sock):
|
||||||
|
chunk = b''
|
||||||
|
while True:
|
||||||
|
for byte in chunk:
|
||||||
|
command = yield byte
|
||||||
|
|
||||||
|
if command is not None:
|
||||||
|
if command == SocketReaderCommands.stop:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise CommandError('%s not recognised' % repr(command))
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk = sock.recv(1024)
|
||||||
|
except socket.timeout:
|
||||||
|
raise SocketReadError('Error reading socket: Remote end timed out')
|
||||||
|
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
|
||||||
# extract_selector_path(selector_path, *, config) → selector, path
|
# extract_selector_path(selector_path, *, config) → selector, path
|
||||||
# Extract selector and path components from a HTTP path
|
# Extract selector and path components from a HTTP path
|
||||||
def extract_selector_path(selector_path, *, config):
|
def extract_selector_path(selector_path, *, config):
|
||||||
if len(selector_path) > 0 and selector_path[0] == '/':
|
if len(selector_path) > 0 and selector_path[0] == '/':
|
||||||
selector_path = selector_path[1:]
|
selector_path = selector_path[1:]
|
||||||
|
|
||||||
if len(selector_path) == 0: # / is by default of type 1
|
if len(selector_path) == 0: # / is by default of type 1
|
||||||
selector = '1'
|
selector = '1'
|
||||||
path = selector_path
|
path = selector_path
|
||||||
|
@ -90,7 +132,7 @@ def extract_selector_path(selector_path, *, config):
|
||||||
else: # We couldn't recognise any selector, return None for it
|
else: # We couldn't recognise any selector, return None for it
|
||||||
selector = None
|
selector = None
|
||||||
path = selector_path
|
path = selector_path
|
||||||
|
|
||||||
return selector, path
|
return selector, path
|
||||||
|
|
||||||
class PathError(OneArgumentException):
|
class PathError(OneArgumentException):
|
||||||
|
@ -122,58 +164,114 @@ def normalize_path(path, *, config):
|
||||||
else:
|
else:
|
||||||
# A normal path component, add to the normalized path
|
# A normal path component, add to the normalized path
|
||||||
normalized_components.append(component)
|
normalized_components.append(component)
|
||||||
|
|
||||||
return '/'.join(normalized_components)
|
|
||||||
|
|
||||||
class Protocol(enum.Enum):
|
return '/'.join(normalized_components)
|
||||||
gopher, http = range(2)
|
|
||||||
|
|
||||||
class RequestError(OneArgumentException):
|
class RequestError(OneArgumentException):
|
||||||
text = 'Error with handling request: %s'
|
text = 'Error with handling request: %s'
|
||||||
|
|
||||||
|
class Protocol(enum.Enum):
|
||||||
|
gopher, gopherplus, http = range(3)
|
||||||
|
|
||||||
# get_request(sock, *, config) → path, protocol, rest
|
# get_request(sock, *, config) → path, protocol, rest
|
||||||
# Read request from socket and parse it.
|
# Read request from socket and parse it.
|
||||||
# path is the requested path, protocol is Protocol.gopher or Protocol.http depending on the request protocol
|
# path is the requested path, protocol is Protocol.gopher or Protocol.http depending on the request protocol
|
||||||
# rest is protocol-dependant information
|
# rest is protocol-dependant information
|
||||||
def get_request(sock, *, config):
|
def get_request(sockreader, *, config):
|
||||||
request = b''
|
protocol = None
|
||||||
|
|
||||||
|
request = bytearray()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
data = sock.recv(1024)
|
request.append(next(sockreader))
|
||||||
except socket.timeout:
|
except StopIteration: # Other end hung up before sending a full header
|
||||||
raise RequestError('Remote end timed out')
|
|
||||||
if not data: # Other end hung up before sending a header
|
|
||||||
raise RequestError('Remote end hung up unexpectedly')
|
raise RequestError('Remote end hung up unexpectedly')
|
||||||
|
|
||||||
if len(request) >= config.request_max_size:
|
if len(request) >= config.request_max_size:
|
||||||
raise RequestError('Request too long')
|
raise RequestError('Request too long')
|
||||||
|
|
||||||
request += data
|
# We have enough data to recognise a HTTP request
|
||||||
|
if protocol is None and len(request) >= 4:
|
||||||
|
# Does it look like a HTTP GET request?
|
||||||
|
if request[:3] == bytearray(b'GET') and chr(request[3]) in [' ', '\r', '\t']:
|
||||||
|
# Yes, mark HTTP as protocol
|
||||||
|
protocol = Protocol.http
|
||||||
|
else:
|
||||||
|
# No, mark Gopher as protocol
|
||||||
|
protocol = Protocol.gopher
|
||||||
|
|
||||||
if b'\n' in data: # First line has been sent, all we care about for now
|
# End of line reached before a HTTP GET request found, mark Gopher as protocol
|
||||||
|
if protocol is None and len(request) >= 1 and request[-1:] == bytearray(b'\n'):
|
||||||
|
protocol = Protocol.gopher
|
||||||
|
|
||||||
|
# Twice CR+LF, end of HTTP request
|
||||||
|
if protocol == Protocol.http and len(request) >= 4 and request[-4:] == bytearray(b'\r\n\r\n'):
|
||||||
break
|
break
|
||||||
|
|
||||||
request = request.decode('utf-8')
|
# Twice LF, malcompliant but support anyways
|
||||||
first_line = request.split('\n')[0]
|
if protocol == Protocol.http and len(request) >=2 and request[-2:] == bytearray(b'\n\n'):
|
||||||
if first_line[-1] == '\r':
|
break
|
||||||
first_line = first_line[:-1]
|
|
||||||
first_line = first_line.split(' ')
|
# CR+LF, end of Gopher request
|
||||||
|
if protocol == Protocol.gopher and len(request) >= 2 and request[-2:] == bytearray(b'\r\n'):
|
||||||
if len(first_line) >= 2 and first_line[0] == 'GET':
|
break
|
||||||
selector_path = first_line[1]
|
|
||||||
|
# LF, malcompliant but support anyways
|
||||||
|
if protocol == Protocol.gopher and len(request) >= 1 and request[-1:] == bytearray(b'\n'):
|
||||||
|
break
|
||||||
|
|
||||||
|
if protocol == Protocol.http:
|
||||||
|
length = len(request)
|
||||||
|
# Start after GET
|
||||||
|
index = 3
|
||||||
|
# Skip witespace
|
||||||
|
while index < length and chr(request[index]) in [' ', '\r', '\n', '\t']: index += 1
|
||||||
|
# Found the start of the requested path
|
||||||
|
path_start = index
|
||||||
|
# Skip until next whitespace (end of requested path)
|
||||||
|
while index < length and chr(request[index]) not in [' ', '\r', '\n', '\t']: index += 1
|
||||||
|
# Found the end of the requested path
|
||||||
|
path_end = index
|
||||||
|
|
||||||
|
selector_path = urllib.parse.unquote(request[path_start:path_end].decode('utf-8'))
|
||||||
selector, path = extract_selector_path(selector_path, config = config)
|
selector, path = extract_selector_path(selector_path, config = config)
|
||||||
protocol = Protocol.http
|
|
||||||
rest = selector
|
rest = selector
|
||||||
else:
|
|
||||||
if len(first_line) >= 1:
|
elif protocol == Protocol.gopher:
|
||||||
path = first_line[0]
|
|
||||||
else:
|
|
||||||
path = ''
|
|
||||||
protocol = Protocol.gophrt
|
|
||||||
rest = None
|
rest = None
|
||||||
|
|
||||||
|
length = len(request)
|
||||||
|
index = 0
|
||||||
|
# Seek until either end of line or a tab (field separator)
|
||||||
|
while index < length and chr(request[index]) not in ['\t', '\r', '\n']: index += 1
|
||||||
|
# Found the end of the path
|
||||||
|
path_end = index
|
||||||
|
|
||||||
|
path = request[:path_end].decode('utf-8')
|
||||||
|
|
||||||
|
# If another field was present, check to see if it marks a Gopher+ request
|
||||||
|
if chr(request[index]) == '\t':
|
||||||
|
index += 1
|
||||||
|
field_start = index
|
||||||
|
# Look until end of line
|
||||||
|
while index < length and chr(request[index]) not in ['\r', '\n']: index += 1
|
||||||
|
field_end = index
|
||||||
|
|
||||||
|
field = request[field_start:field_end].decode('utf-8')
|
||||||
|
# We recognise these as signalling a Gopher+ request
|
||||||
|
if len(field) >= 1 and field[0] in ['+', '!', '$']:
|
||||||
|
# It was Gopher+, let's update protocol value and stash the field into rest
|
||||||
|
protocol = Protocol.gopherplus
|
||||||
|
rest = field
|
||||||
|
|
||||||
|
else:
|
||||||
|
unreachable()
|
||||||
|
|
||||||
path = normalize_path(path, config = config)
|
path = normalize_path(path, config = config)
|
||||||
|
|
||||||
return path, Protocol.gopher, None
|
return path, protocol, None
|
||||||
|
|
||||||
# Worker thread implementation
|
# Worker thread implementation
|
||||||
class Serve(threading.Thread):
|
class Serve(threading.Thread):
|
||||||
|
@ -183,12 +281,13 @@ class Serve(threading.Thread):
|
||||||
self.address = address
|
self.address = address
|
||||||
self.config = config
|
self.config = config
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
|
|
||||||
def handle_request(self):
|
def handle_request(self):
|
||||||
path, protocol, rest = get_request(self.sock, config = self.config)
|
sockreader = SocketReader(self.sock)
|
||||||
|
path, protocol, rest = get_request(sockreader, config = self.config)
|
||||||
answer = str((path, protocol, rest))+'\n'
|
answer = str((path, protocol, rest))+'\n'
|
||||||
self.sock.sendall(answer.encode('utf-8'))
|
self.sock.sendall(answer.encode('utf-8'))
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
global threads_amount, threads_lock
|
global threads_amount, threads_lock
|
||||||
|
|
||||||
|
@ -199,12 +298,12 @@ class Serve(threading.Thread):
|
||||||
finally:
|
finally:
|
||||||
self.sock.close()
|
self.sock.close()
|
||||||
self.controller.thread_end()
|
self.controller.thread_end()
|
||||||
|
|
||||||
class Threads_controller:
|
class Threads_controller:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.threads_amount = 0
|
self.threads_amount = 0
|
||||||
self.threads_lock = threading.Lock()
|
self.threads_lock = threading.Lock()
|
||||||
|
|
||||||
# .spawn_thread(sock, address, config)
|
# .spawn_thread(sock, address, config)
|
||||||
# Spawn a new thread to serve a connection if possible, do nothing if not
|
# Spawn a new thread to serve a connection if possible, do nothing if not
|
||||||
def spawn_thread(self, sock, address, config):
|
def spawn_thread(self, sock, address, config):
|
||||||
|
@ -216,10 +315,10 @@ class Threads_controller:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.threads_amount += 1
|
self.threads_amount += 1
|
||||||
|
|
||||||
# Spawn a new worker thread
|
# Spawn a new worker thread
|
||||||
Serve(self, sock, address, config).start()
|
Serve(self, sock, address, config).start()
|
||||||
|
|
||||||
# .thread_end()
|
# .thread_end()
|
||||||
# Called from worker thread to signal it's exiting
|
# Called from worker thread to signal it's exiting
|
||||||
def thread_end(self):
|
def thread_end(self):
|
||||||
|
@ -249,7 +348,7 @@ def listen(config):
|
||||||
|
|
||||||
# Create a controller object for the worker threads
|
# Create a controller object for the worker threads
|
||||||
threads_controller = Threads_controller()
|
threads_controller = Threads_controller()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# Wait for listening sockets to get activity
|
# Wait for listening sockets to get activity
|
||||||
events = listening.poll()
|
events = listening.poll()
|
||||||
|
|
Loading…
Reference in a new issue