# -*- coding: ascii -*-
#
#  Copyright (C) 2001, 2002 by Tamito KAJIYAMA
#  Copyright (C) 2002, 2003 by MATSUMURA Namihiko <nie@counterghost.net>
#  Copyright (C) 2002-2005 by Shyouzou Sugitani <shy@users.sourceforge.jp>
#
#  This program is free software; you can redistribute it and/or modify it
#  under the terms of the GNU General Public License (version 2) as
#  published by the Free Software Foundation.  It is distributed in the
#  hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
#  implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
#  PURPOSE.  See the GNU General Public License for more details.
#
# $Id: sstp.py,v 1.15.2.1 2005/05/27 07:38:05 shy Exp $
#

import codecs

import ninix.entry_db
import ninix.script
import ninix.version

from sstplib import AsynchronousSSTPServer, BaseSSTPRequestHandler

class SSTPServer(AsynchronousSSTPServer):
    """SSTP server that forwards ghost-related queries to the application.

    At most one request handler is remembered in ``request_handler``; it
    is the handler whose reply is sent later through send_response()
    and its convenience wrappers.
    """

    def __init__(self, address, app):
        """Create the server on ``address`` and remember ``app``."""
        self.__app = app
        AsynchronousSSTPServer.__init__(self, address, SSTPRequestHandler)
        self.request_handler = None

    def send_response(self, code, data=None):
        """Send a deferred reply through the remembered request handler.

        ``code`` is the status code; ``data``, when not None, is written
        to the handler's output stream after the status line.  The
        handler is then finished and forgotten.  Nothing is done when no
        handler is currently remembered.
        """
        handler = self.request_handler
        if handler is None:
            # No pending request (for example the reply was already
            # sent); there is nobody to answer.
            return
        try:
            try:
                handler.send_response(code)
                if data is not None:
                    handler.wfile.write(data)
                handler.finish(1)
            except IOError:
                # Best effort: the peer may already be gone.
                pass
        finally:
            # Always release the slot, otherwise every later request
            # carrying entries would be answered with 409.
            self.request_handler = None

    def send_answer(self, value):
        """Reply 200 with ``value`` encoded as UTF-8 plus a Charset line."""
        self.send_response(200, # OK
                           value.encode('utf-8', 'ignore') + "\r\nCharset: UTF-8\r\n")

    def send_no_content(self):
        """Reply 204 to the pending request."""
        self.send_response(204) # No Content

    def send_sstp_break(self):
        """Reply 210 to the pending request."""
        self.send_response(210) # Break

    def send_timeout(self):
        """Reply 408 to the pending request."""
        self.send_response(408) # Request Timeout

    def close(self):
        """Close the listening socket."""
        self.socket.close()

    def get_current_sakura(self):
        """Return the application's current sakura object."""
        return self.__app.get_current_sakura()

    def if_ghost(self, ifghost):
        """Return the application's answer for the ``ifghost`` name."""
        return self.__app.if_ghost(ifghost)

    def get_ghost_names(self):
        """Return the ghost names reported by the application."""
        return self.__app.get_ghost_names()

    def check_request_queue(self, sender):
        """Return the application's request queue status for ``sender``."""
        return self.__app.check_request_queue(sender)

    def enqueue_script_if_ghost(self, if_ghost, script, sender, handle, address, show_sstp_marker, use_translator, entry_db):
        """Hand a script addressed to ``if_ghost`` over to the application."""
        self.__app.enqueue_script_if_ghost(if_ghost, script, sender, handle, address, show_sstp_marker, use_translator, entry_db)

class SSTPRequestHandler(BaseSSTPRequestHandler):

    def handle(self):
        """Process the request, or refuse it while the ghost cannot talk.

        When the current sakura reports that it cannot talk, only the
        request line is read and parsed; a 512 error is sent if that
        line parses.  Otherwise the base class handles the request.
        """
        if self.server.get_current_sakura().cantalk:
            BaseSSTPRequestHandler.handle(self)
            return
        self.error = None
        self.version = None
        request_line = self.rfile.readline()
        if self.parse_request(request_line):
            self.send_error(512)

    def finish(self, forced=0):
        """Finish this request unless its reply has been deferred.

        The handler is left open only while the server remembers *this*
        handler in ``request_handler`` (set by enqueue_script() to keep
        the connection alive); ``forced`` overrides that.  Comparing
        against ``self`` instead of None makes sure other requests that
        arrive while a deferred one is pending are still finished.
        """
        if forced or self.server.request_handler is not self:
            BaseSSTPRequestHandler.finish(self)

    # SEND
    def do_SEND_1_0(self):
        """Entry point for SEND version 1.0 requests."""
        self.handle_send(version=1.0)

    def do_SEND_1_1(self):
        """Entry point for SEND version 1.1 requests."""
        self.handle_send(version=1.1)

    def do_SEND_1_2(self):
        """Entry point for SEND version 1.2 requests."""
        self.handle_send(version=1.2)

    def do_SEND_1_3(self):
        """Entry point for SEND version 1.3 requests."""
        self.handle_send(version=1.3)

    def do_SEND_1_4(self):
        """Entry point for SEND version 1.4 requests."""
        self.handle_send(version=1.4)

    def handle_send(self, version):
        """Common implementation of the SEND request for all versions.

        ``version`` is the protocol version as a float (1.0 .. 1.4).
        Version 1.3 additionally needs a handle from the HWnd header,
        version 1.4 picks its script through IfGhost/Script pairs, and
        versions 1.2 and later collect Entry headers.  Each helper sends
        its own error response, so this method just stops when one of
        them reports failure.
        """
        if not self.check_decoder():
            return
        sender = self.get_sender()
        if sender is None:
            return
        handle = None
        if version == 1.3:
            handle = self.get_handle()
            if handle is None:
                return
        uses_if_ghost = (version == 1.4)
        if uses_if_ghost:
            script, if_ghost = self.get_script_if_ghost()
        else:
            script = self.get_script()
        if script is None:
            return
        if self.check_script(script):
            # check_script() has already answered with 400
            return
        if version in (1.0, 1.1):
            entry_db = None
        elif version in (1.2, 1.3, 1.4):
            entry_db = self.get_entry_db()
            if entry_db is None:
                return
        if uses_if_ghost:
            self.enqueue_script_if_ghost(if_ghost, sender, handle, script, entry_db)
        else:
            self.enqueue_script(sender, handle, script, entry_db)

    # NOTIFY
    def do_NOTIFY_1_0(self):
        """Entry point for NOTIFY version 1.0 requests."""
        self.handle_notify(version=1.0)

    def do_NOTIFY_1_1(self):
        """Entry point for NOTIFY version 1.1 requests."""
        self.handle_notify(version=1.1)

    def handle_notify(self, version):
        """Common implementation of the NOTIFY request.

        ``version`` is the protocol version as a float (1.0 or 1.1).
        The event (name plus Reference0..7) is passed to the current
        sakura's get_event_response(); a non-empty result is used as the
        script.  For version 1.1 a script from the request headers is
        used when the ghost returns nothing, and Entry headers are
        collected.  With no script at all the request is answered 200
        and nothing is enqueued.  Helpers send their own error
        responses.
        """
        if not self.check_decoder():
            return
        sender = self.get_sender()
        if sender is None:
            return
        event = self.get_event()
        if event is None:
            return
        sakura = self.server.get_current_sakura()
        # extended call syntax instead of the deprecated apply() builtin
        script = sakura.get_event_response(*event)
        if script and self.check_script(script):
            return
        entry_db = None
        if version == 1.1:
            if not script:
                # fall back to the script supplied by the client
                script = self.get_script_if_ghost(current=1)[0]
                if script is None or self.check_script(script):
                    return
            entry_db = self.get_entry_db()
            if entry_db is None:
                return
        if not script:
            self.send_response(200) # OK
            return
        self.enqueue_script(sender, None, script, entry_db)

    def enqueue_script_if_ghost(self, if_ghost, sender, handle, script, entry_db):
        """Answer 200 and pass the script for ``if_ghost`` to the server.

        The client address handed on is the host part when
        ``client_address`` is a sequence such as (host, port); any other
        value (presumably a UNIX-domain socket path - confirm against
        the server setup) is passed through unchanged.  Options are read
        from the Option header via get_options().
        """
        client = self.client_address
        if isinstance(client, (tuple, list)) and client:
            address = client[0]
        else:
            # Indexing a string here would yield only its first character.
            address = client
        self.send_response(200) # OK
        show_sstp_marker, use_translator = self.get_options()
        self.server.enqueue_script_if_ghost(
            if_ghost, script, sender, handle,
            address, show_sstp_marker, use_translator, entry_db)

    def enqueue_script(self, sender, handle, script, entry_db):
        """Queue ``script`` on the current sakura and answer the client.

        Without entries (``entry_db`` is None or empty) the request is
        answered 200 immediately.  With entries the reply is deferred:
        the script is queued together with ``entry_db`` and the server,
        and this handler is stored in ``server.request_handler`` to keep
        the connection alive.  If the server already holds a handler the
        request is refused with 409.

        The client address handed on is the host part when
        ``client_address`` is a sequence such as (host, port); any other
        value (presumably a UNIX-domain socket path - confirm against
        the server setup) is passed through unchanged.
        """
        client = self.client_address
        if isinstance(client, (tuple, list)) and client:
            address = client[0]
        else:
            # Indexing a string here would yield only its first character.
            address = client
        if entry_db is None or entry_db.is_empty():
            self.send_response(200) # OK
            show_sstp_marker, use_translator = self.get_options()
            sakura = self.server.get_current_sakura()
            sakura.enqueue_script(
                script, sender, handle,
                address, show_sstp_marker, use_translator)
        elif self.server.request_handler:
            self.send_response(409) # Conflict
        else:
            show_sstp_marker, use_translator = self.get_options()
            sakura = self.server.get_current_sakura()
            sakura.enqueue_script(
                script, sender, handle,
                address, show_sstp_marker, use_translator,
                entry_db, self.server)
            self.server.request_handler = self # keep alive

    # Script tags that check_script() refuses (400) when the request is
    # not a local one.
    PROHIBITED_TAGS = [r"\j", r"\-", r"\+", r"\_+", r"\!"]

    def check_script(self, script):
        if not self.local_request():
            parser = ninix.script.Parser()
            nodes = []
            while 1:
                try:
                    nodes.extend(parser.parse(script))
                except ninix.script.ParserError, e:
                    done, script = e
                    nodes.extend(done)
                else:
                    break
            for node in nodes:
                if node[0] == ninix.script.SCRIPT_TAG and \
                   node[1] in self.PROHIBITED_TAGS:
                    self.send_response(400) # Bad Request
                    self.log_error("Script: tag %s not allowed" % node[1])
                    return 1
        return 0

    def get_script(self):
        """Return the Script header as unicode.

        The value is decoded with the request's Charset header
        (Shift_JIS when absent).  If there is no Script header a 400
        response is sent, the error is logged and None is returned.
        """
        raw = self.headers.get("script", None)
        if raw is not None:
            encoding = self.headers.get("charset", "Shift_JIS")
            return unicode(raw, encoding, "replace")
        self.send_response(400) # Bad Request
        self.log_error("Script: header field not found")
        return None

    def get_script_if_ghost(self, current=0):
        """Pick a script from the IfGhost/Script header pairs.

        The raw header lines are scanned in order for an ``IfGhost:``
        line immediately followed by a ``Script:`` line.  With
        ``current`` true (NOTIFY) a pair matches when its IfGhost value
        equals ghost_names(); otherwise (SEND) the server's if_ghost()
        decides.  The first matching pair is returned as
        ``(script, if_ghost)``.  If nothing matches, the first pair seen
        is returned; if there are no pairs at all the result is
        ``(self.get_script(), None)``.

        A line following ``IfGhost:`` that is not a ``Script:`` line is
        not consumed, so it is still examined as a possible ``IfGhost:``
        line of its own.
        """
        charset = self.headers.get("charset", "Shift_JIS")
        lines = self.headers.headers
        total = len(lines)
        default = None
        i = 0
        while i < total:
            line = unicode(lines[i], charset, "replace")
            i += 1
            if line[:8].lower() != "ifghost:" or i >= total:
                continue
            following = unicode(lines[i], charset, "replace")
            if following[:7].lower() != "script:":
                # not a pair; leave the line for the next iteration
                continue
            i += 1
            if_ghost = line[8:].strip()
            script = following[7:].strip()
            if current: # NOTIFY
                if self.ghost_names() == if_ghost:
                    return script, if_ghost
            else: # SEND
                if self.server.if_ghost(if_ghost):
                    return script, if_ghost
            if default is None:
                default = script, if_ghost
        if default is None:
            return self.get_script(), None
        return default

    def get_entry_db(self):
        """Build an EntryDatabase from the request's Entry headers.

        Each header value is decoded with the request charset
        (Shift_JIS when absent) and split at the first comma; both
        parts are stripped before being added.  A value without a comma
        causes a 400 response and a None result.
        """
        encoding = self.headers.get("charset", "Shift_JIS")
        database = ninix.entry_db.EntryDatabase()
        for raw in self.headers.getallmatchingheaders("entry"):
            # raw[6:] skips the leading "Entry:" of the header line
            pair = unicode(raw[6:], encoding, "replace").split(",", 1)
            if len(pair) != 2:
                self.send_response(400) # Bad Request
                return None
            first, second = pair
            database.add(first.strip(), second.strip())
        return database

    def get_event(self):
        """Return the event as a 9-tuple (name, Reference0 .. Reference7).

        All present values are decoded with the request charset
        (Shift_JIS when absent); missing Reference headers are None.
        If the Event header is missing, a 400 response is sent, the
        error is logged and None is returned.
        """
        encoding = self.headers.get("charset", "Shift_JIS")
        name = self.headers.get("event", None)
        if name is None:
            self.send_response(400) # Bad Request
            self.log_error("Event: header field not found")
            return None
        fields = [unicode(name, encoding, "replace")]
        for index in range(8):
            reference = self.headers.get("reference%d" % index, None)
            if reference is None:
                fields.append(None)
            else:
                fields.append(unicode(reference, encoding, "replace"))
        return tuple(fields)

    def get_sender(self):
        """Return the Sender header as unicode.

        The value is decoded with the request's Charset header
        (Shift_JIS when absent).  If there is no Sender header a 400
        response is sent, the error is logged and None is returned.
        """
        raw = self.headers.get("sender", None)
        if raw is not None:
            encoding = self.headers.get("charset", "Shift_JIS")
            return unicode(raw, encoding, "replace")
        self.send_response(400) # Bad Request
        self.log_error("Sender: header field not found")
        return None

    def get_handle(self):
        """Open the SSTP handle named by the HWnd header.

        The header value is given to the current sakura's
        open_sstp_handle().  A missing header, or a None result from
        open_sstp_handle(), is answered with 400, logged, and reported
        to the caller as None.
        """
        hwnd = self.headers.get("hwnd", None)
        if hwnd is None:
            self.send_response(400) # Bad Request
            self.log_error("HWnd: header field not found")
            return None
        handle = self.server.get_current_sakura().open_sstp_handle(hwnd)
        if handle is not None:
            return handle
        self.send_response(400) # Bad Request
        self.log_error("Invalid HWnd: header field")
        return None

    def check_decoder(self):
        """Check that the request's charset is known to the codecs module.

        The charset comes from the Charset header (Shift_JIS when
        absent).  Returns 1 when codecs.lookup() accepts it.  Otherwise
        a 420 response is sent, the error is logged and 0 is returned.
        """
        charset = self.headers.get("charset", "Shift_JIS")
        try:
            codecs.lookup(charset)
        except Exception:
            # Any lookup failure means the charset is unusable; unlike a
            # bare except this does not trap SystemExit/KeyboardInterrupt.
            self.send_response(420, "Refuse (unsupported charset)")
            self.log_error("Unsupported charset %s" % repr(charset))
            return 0
        return 1

    def get_options(self):
        """Return (show_sstp_marker, use_translator) from the Option header.

        Both flags default to 1.  The header is a comma separated list:
        "nodescript" clears show_sstp_marker, but only for local
        requests; "notranslate" clears use_translator.
        """
        show_sstp_marker = 1
        use_translator = 1
        requested = self.headers.get("option", "")
        for name in [item.strip() for item in requested.split(",")]:
            if name == "nodescript":
                if self.local_request():
                    show_sstp_marker = 0
            elif name == "notranslate":
                use_translator = 0
        return show_sstp_marker, use_translator

    def local_request(self):
        """Return a true value when the request comes from this machine.

        When ``client_address`` is a sequence such as (host, port), the
        request is local only if the host is "127.0.0.1".  Any other
        kind of address (presumably a UNIX-domain socket path - confirm
        against the server setup) is treated as local.

        The previous try/except form could never reach its host check,
        because merely reading ``client_address`` does not raise, so
        every request was reported as local.
        """
        client = self.client_address
        if isinstance(client, (tuple, list)):
            if client and client[0] == "127.0.0.1":
                return 1
            return 0
        return 1

    def ghost_names(self):
        """Return "<selfname>,<keroname>" for the current sakura."""
        sakura = self.server.get_current_sakura()
        selfname = sakura.get_selfname()
        keroname = sakura.get_keroname()
        return "%s,%s" % (selfname, keroname)

    # EXECUTE
    def do_EXECUTE_1_0(self):
        """Entry point for EXECUTE version 1.0; runs the command as is."""
        return self.handle_command()

    def do_EXECUTE_1_2(self):
        """Entry point for EXECUTE version 1.2; runs the command as is."""
        return self.handle_command()

    def do_EXECUTE_1_3(self):
        """Entry point for EXECUTE version 1.3 (local requests only).

        Requests that are not local are answered with 420 and logged
        with the client's host; local ones go on to handle_command().
        """
        if self.local_request():
            self.handle_command()
            return
        host, port = self.client_address
        self.send_response(420)
        self.log_error("Unauthorized EXECUTE/1.3 request from %s" % host)

    def do_EXECUTE_1_5(self):
        """Entry point for EXECUTE version 1.5 (local requests only).

        Requests that are not local are answered with 420 and logged
        with the client's host; local ones go on to handle_command().
        """
        if self.local_request():
            self.handle_command()
            return
        host, port = self.client_address
        self.send_response(420)
        self.log_error("Unauthorized EXECUTE/1.5 request from %s" % host)

    def handle_command(self):
        """Run the command named by the Command header of an EXECUTE.

        Known commands are answered with 200 followed by their output:
          getname    - ghost_names() in UTF-8, then a Charset line
          getversion - the ninix-aya version string
          quiet      - keep_silence(1) on the current sakura
          restore    - keep_silence(0) on the current sakura
          reload     - reload() on the current sakura
          getnames   - one line per server ghost name, then a Charset line
          checkqueue - the two values from server.check_request_queue()
        Anything else is answered with 501 and logged.  Helpers send
        their own error responses when a required header is missing.
        """
        if not self.check_decoder():
            return
        sender = self.get_sender()
        if sender is None:
            return
        command = self.get_command()
        if command is None:
            return
        known = ("getname", "getversion", "quiet", "restore",
                 "reload", "getnames", "checkqueue")
        if command not in known:
            self.send_response(501) # Not Implemented
            self.log_error("Not Implemented (%s)" % command)
            return
        self.send_response(200)
        out = self.wfile
        if command == "getname":
            name = self.ghost_names()
            out.write(name.encode('utf-8', 'ignore') + '\r\n')
            out.write('Charset: UTF-8\r\n')
        elif command == "getversion":
            out.write("ninix-aya %s\r\n" % ninix.version.VERSION)
        elif command == "quiet":
            self.server.get_current_sakura().keep_silence(1)
        elif command == "restore":
            self.server.get_current_sakura().keep_silence(0)
        elif command == "reload":
            self.server.get_current_sakura().reload()
        elif command == "getnames":
            for name in self.server.get_ghost_names():
                out.write(name.encode('utf-8', 'ignore') + '\r\n')
            out.write('Charset: UTF-8\r\n')
        else: # checkqueue
            count, total = self.server.check_request_queue(sender)
            out.write(count + '\r\n')
            out.write(total + '\r\n')

    def get_command(self):
        """Return the Command header as lower-cased unicode.

        The value is decoded with the request's Charset header
        (Shift_JIS when absent).  If there is no Command header a 400
        response is sent, the error is logged and None is returned.
        """
        raw = self.headers.get("command", None)
        if raw is not None:
            encoding = self.headers.get("charset", "Shift_JIS")
            decoded = unicode(raw, encoding, "replace")
            return decoded.lower()
        self.send_response(400) # Bad Request
        self.log_error("Command: header field not found")
        return None

    def do_COMMUNICATE_1_1(self): ## FIXME
        """Entry point for COMMUNICATE version 1.1 requests.

        After the charset, Sender and Sentence headers have been
        validated, the request is answered 200 and an "OnCommunicate"
        event carrying the sender and the sentence is enqueued on the
        current sakura.  Helpers send their own error responses.
        """
        if not self.check_decoder():
            return
        sender = self.get_sender()
        if sender is not None:
            sentence = self.get_sentence()
            if sentence is not None:
                self.send_response(200) # OK
                ghost = self.server.get_current_sakura()
                ghost.enqueue_event("OnCommunicate", sender, sentence)

    def get_sentence(self):
        """Return the Sentence header as unicode.

        The value is decoded with the request's Charset header
        (Shift_JIS when absent).  If there is no Sentence header a 400
        response is sent, the error is logged and None is returned.
        """
        raw = self.headers.get("sentence", None)
        if raw is not None:
            encoding = self.headers.get("charset", "Shift_JIS")
            return unicode(raw, encoding, "replace")
        self.send_response(400) # Bad Request
        self.log_error("Sentence: header field not found")
        return None
