Bring code up to PEP8 standard

This commit is contained in:
Tanner Collin 2017-10-14 13:33:15 -06:00
parent c7309c4136
commit cf73e5ab21
5 changed files with 159 additions and 139 deletions

42
api.py
View File

@ -1,4 +1,4 @@
import json, requests, time
import requests
from crypt import EncryptionHelper
@ -18,47 +18,51 @@ class RESTAPI:
url = self.base_url + route
return requests.post(url, json=data, headers=self.headers).json()
def addHeader(self, header):
def add_header(self, header):
self.headers.update(header)
class StandardNotesAPI:
encryption_helper = EncryptionHelper()
sync_token = None
def getAuthParamsForEmail(self):
def get_auth_params_for_email(self):
return self.api.get('/auth/params', dict(email=self.username))
def genKeys(self, password):
pw_info = self.getAuthParamsForEmail()
def gen_keys(self, password):
pw_info = self.get_auth_params_for_email()
if 'error' in pw_info:
raise SNAPIException(pw_info['error']['message'])
return self.encryption_helper.pure_generatePasswordAndKey(password, pw_info['pw_salt'], pw_info['pw_cost'])
return self.encryption_helper.pure_generate_password_and_key(
password, pw_info['pw_salt'], pw_info['pw_cost'])
def signIn(self, keys):
def sign_in(self, keys):
self.keys = keys
res = self.api.post('/auth/sign_in', dict(email=self.username, password=self.keys['pw']))
res = self.api.post('/auth/sign_in', dict(email=self.username,
password=self.keys['pw']))
if 'error' in res:
raise SNAPIException(res['error']['message'])
self.api.addHeader(dict(Authorization='Bearer ' + res['token']))
self.api.add_header(dict(Authorization='Bearer ' + res['token']))
def sync(self, dirty_items):
items = self.handleDirtyItems(dirty_items)
response = self.api.post('/items/sync', dict(sync_token=self.sync_token, items=items))
items = self.handle_dirty_items(dirty_items)
response = self.api.post('/items/sync', dict(sync_token=self.sync_token,
items=items))
self.sync_token = response['sync_token']
return self.handleResponseItems(response)
return self.handle_response_items(response)
def handleDirtyItems(self, dirty_items):
items = self.encryption_helper.encryptDirtyItems(dirty_items, self.keys)
def handle_dirty_items(self, dirty_items):
items = self.encryption_helper.encrypt_dirty_items(
dirty_items, self.keys)
return items
def handleResponseItems(self, response):
response_items = self.encryption_helper.decryptResponseItems(response['retrieved_items'], self.keys)
saved_items = self.encryption_helper.decryptResponseItems(response['saved_items'], self.keys)
def handle_response_items(self, response):
response_items = self.encryption_helper.decrypt_response_items(
response['retrieved_items'], self.keys)
saved_items = self.encryption_helper.decrypt_response_items(
response['saved_items'], self.keys)
return dict(response_items=response_items, saved_items=saved_items)
def __init__(self, base_url, username):

View File

@ -1,14 +1,28 @@
import hashlib, hmac, json
from base64 import b64encode, b64decode
from base64 import b64decode, b64encode
from binascii import hexlify, unhexlify
from Crypto.Cipher import AES
from Crypto.Random import random
from copy import deepcopy
import hashlib
import hmac
import json
import sys
from Crypto.Cipher import AES
from Crypto.Random import random
BITS_PER_HEX_DIGIT = 4
PASS_KEY_LEN = 96
AES_KEY_LEN = 256
AES_BLK_SIZE = 16
AES_STR_KEY_LEN = AES_KEY_LEN // BITS_PER_HEX_DIGIT
AES_IV_LEN = 128
AES_STR_IV_LEN = AES_IV_LEN // BITS_PER_HEX_DIGIT
class EncryptionHelper:
def pure_generatePasswordAndKey(self, password, pw_salt, pw_cost):
output = hashlib.pbkdf2_hmac('sha512', password.encode(), pw_salt.encode(), pw_cost, dklen=96)
def pure_generate_password_and_key(self, password, pw_salt, pw_cost):
output = hashlib.pbkdf2_hmac(
'sha512', password.encode(), pw_salt.encode(), pw_cost,
dklen=PASS_KEY_LEN)
output = hexlify(output).decode()
output_length = len(output)
@ -19,36 +33,39 @@ class EncryptionHelper:
return dict(pw=pw, mk=mk, ak=ak)
def encryptDirtyItems(self, dirty_items, keys):
return [self.pure_encryptItem(item, keys) for item in dirty_items]
def encrypt_dirty_items(self, dirty_items, keys):
return [self.pure_encrypt_item(item, keys) for item in dirty_items]
def decryptResponseItems(self, response_items, keys):
return [self.pure_decryptItem(item, keys) for item in response_items]
def decrypt_response_items(self, response_items, keys):
return [self.pure_decrypt_item(item, keys) for item in response_items]
def pure_encryptItem(self, item, keys):
def pure_encrypt_item(self, item, keys):
uuid = item['uuid']
content = json.dumps(item['content'])
item_key = hex(random.getrandbits(512))
item_key = item_key[2:].rjust(128, '0') # remove '0x', pad to 128
item_key_length = len(item_key)
item_ek = item_key[:item_key_length//2]
item_ak = item_key[item_key_length//2:]
# all this is to follow the Standard Notes spec
item_key = hex(random.getrandbits(AES_KEY_LEN * 2))
# remove '0x', pad with 0's, then split in half
item_key = item_key[2:].rjust(AES_STR_KEY_LEN * 2, '0')
item_ek = item_key[:AES_STR_KEY_LEN]
item_ak = item_key[AES_STR_KEY_LEN:]
enc_item = deepcopy(item)
enc_item['content'] = self.pure_encryptString002(content, item_ek, item_ak, uuid)
enc_item['enc_item_key'] = self.pure_encryptString002(item_key, keys['mk'], keys['ak'], uuid)
enc_item['content'] = self.pure_encrypt_string_002(
content, item_ek, item_ak, uuid)
enc_item['enc_item_key'] = self.pure_encrypt_string_002(
item_key, keys['mk'], keys['ak'], uuid)
return enc_item
def pure_decryptItem(self, item, keys):
def pure_decrypt_item(self, item, keys):
if item['deleted']:
return item
uuid = item['uuid']
content = item['content']
enc_item_key = item['enc_item_key']
if not content:
return item
if content[:3] == '001':
print('Old encryption protocol detected. This version is not '
'supported by standardnotes-fs. Please resync all of '
@ -56,12 +73,14 @@ class EncryptionHelper:
'https://standardnotes.org/help/resync')
sys.exit(1)
elif content[:3] == '002':
item_key = self.pure_decryptString002(enc_item_key, keys['mk'], keys['ak'], uuid)
item_key = self.pure_decrypt_string_002(
enc_item_key, keys['mk'], keys['ak'], uuid)
item_key_length = len(item_key)
item_ek = item_key[:item_key_length//2]
item_ak = item_key[item_key_length//2:]
dec_content = self.pure_decryptString002(content, item_ek, item_ak, uuid)
dec_content = self.pure_decrypt_string_002(
content, item_ek, item_ak, uuid)
else:
print('Invalid protocol version. This could indicate tampering or '
'that something is wrong with the server. Exiting.')
@ -72,31 +91,30 @@ class EncryptionHelper:
return dec_item
def pure_encryptString002(self, string_to_encrypt, encryption_key, auth_key, uuid):
IV = hex(random.getrandbits(128))
IV = IV[2:].rjust(32, '0') # remove '0x', pad to 32
def pure_encrypt_string_002(self, string_to_encrypt, encryption_key,
auth_key, uuid):
IV = hex(random.getrandbits(AES_IV_LEN))
IV = IV[2:].rjust(AES_STR_IV_LEN, '0') # remove '0x', pad with 0's
cipher = AES.new(unhexlify(encryption_key), AES.MODE_CBC, unhexlify(IV))
pt = string_to_encrypt.encode()
pad = 16 - len(pt) % 16
pad = AES_BLK_SIZE - len(pt) % AES_BLK_SIZE
padded_pt = pt + pad * bytes([pad])
ciphertext = b64encode(cipher.encrypt(padded_pt)).decode()
string_to_auth = ':'.join(['002', uuid, IV, ciphertext])
auth_hash = hmac.new(unhexlify(auth_key), string_to_auth.encode(), 'sha256').digest()
auth_hash = hmac.new(
unhexlify(auth_key), string_to_auth.encode(), 'sha256').digest()
auth_hash = hexlify(auth_hash).decode()
result = ':'.join(['002', auth_hash, uuid, IV, ciphertext])
return result
def pure_decryptString002(self, string_to_decrypt, encryption_key, auth_key, uuid):
def pure_decrypt_string_002(self, string_to_decrypt, encryption_key,
auth_key, uuid):
components = string_to_decrypt.split(':')
version = components[0]
auth_hash = components[1]
local_uuid = components[2]
IV = components[3]
ciphertext = components[4]
version, auth_hash, local_uuid, IV, ciphertext = components
if local_uuid != uuid:
print('UUID does not match. This could indicate tampering or '
@ -104,7 +122,8 @@ class EncryptionHelper:
sys.exit(1)
string_to_auth = ':'.join([version, uuid, IV, ciphertext])
local_auth_hash = hmac.new(unhexlify(auth_key), string_to_auth.encode(), 'sha256').digest()
local_auth_hash = hmac.new(
unhexlify(auth_key), string_to_auth.encode(), 'sha256').digest()
local_auth_hash = hexlify(local_auth_hash).decode()
if local_auth_hash != auth_hash:

View File

@ -1,10 +1,11 @@
from api import StandardNotesAPI
from uuid import uuid1
from api import StandardNotesAPI
class ItemManager:
items = {}
def mapResponseItemsToLocalItems(self, response_items, metadata_only=False):
def map_items(self, response_items, metadata_only=False):
DATA_KEYS = ['content', 'enc_item_key', 'auth_hash']
for response_item in response_items:
@ -25,7 +26,7 @@ class ItemManager:
continue
self.items[uuid][key] = value
def syncItems(self):
def sync_items(self):
dirty_items = [item for uuid, item in self.items.items() if item['dirty']]
# remove keys (note: this removes them from self.items as well)
@ -34,12 +35,13 @@ class ItemManager:
item.pop('updated_at', None)
response = self.sn_api.sync(dirty_items)
self.mapResponseItemsToLocalItems(response['response_items'])
self.mapResponseItemsToLocalItems(response['saved_items'], metadata_only=True)
self.map_items(response['response_items'])
self.map_items(response['saved_items'], metadata_only=True)
def getNotes(self):
def get_notes(self):
notes = {}
sorted_items = sorted(self.items.items(), key=lambda x: x[1]['created_at'])
sorted_items = sorted(
self.items.items(), key=lambda x: x[1]['created_at'])
for uuid, item in sorted_items:
if item['content_type'] == 'Note':
@ -55,49 +57,45 @@ class ItemManager:
# remove title duplicates by adding a number to the end
count = 0
while True:
title = note['title'] + ('' if not count else ' ' + str(count + 1))
title = note['title'] + ('' if not count else
' ' + str(count + 1))
if title in notes:
count += 1
else:
break
notes[title] = dict(text=text,
created=item['created_at'],
notes[title] = dict(
text=text, created=item['created_at'],
modified=item.get('updated_at', item['created_at']),
uuid=item['uuid'])
return notes
def touchNote(self, uuid):
def touch_note(self, uuid):
item = self.items[uuid]
item['dirty'] = True
def writeNote(self, uuid, text):
def write_note(self, uuid, text):
item = self.items[uuid]
item['content']['text'] = text.decode() # convert back to string
item['dirty'] = True
def createNote(self, name, time):
def create_note(self, name, time):
uuid = str(uuid1())
content = dict(title=name, text='', references=[])
self.items[uuid] = dict(content_type='Note',
dirty=True,
auth_hash=None,
uuid=uuid,
created_at=time,
updated_at=time,
enc_item_key='',
content=content)
self.items[uuid] = dict(content_type='Note', dirty=True, auth_hash=None,
uuid=uuid, created_at=time, updated_at=time,
enc_item_key='', content=content)
def renameNote(self, uuid, new_note_name):
def rename_note(self, uuid, new_note_name):
item = self.items[uuid]
item['content']['title'] = new_note_name
item['dirty'] = True
def deleteNote(self, uuid):
def delete_note(self, uuid):
item = self.items[uuid]
item['deleted'] = True
item['dirty'] = True
def __init__(self, sn_api):
self.sn_api = sn_api
self.syncItems()
self.sync_items()

View File

@ -1,23 +1,22 @@
from datetime import datetime
import errno
import iso8601
import logging
import os
import time
from stat import S_IFDIR, S_IFREG
from sys import argv, exit
from datetime import datetime
from pathlib import PurePath
from threading import Thread, Event
from stat import S_IFDIR, S_IFREG
from threading import Event, Thread
from time import sleep
from fuse import FUSE, FuseOSError, Operations, LoggingMixIn
from itemmanager import ItemManager
from fuse import FuseOSError, LoggingMixIn, Operations
import iso8601
from requests.exceptions import ConnectionError
from itemmanager import ItemManager
class StandardNotesFUSE(LoggingMixIn, Operations):
def __init__(self, sn_api, sync_sec, path='.'):
self.item_manager = ItemManager(sn_api)
self.notes = self.item_manager.getNotes()
self.notes = self.item_manager.get_notes()
self.uid = os.getuid()
self.gid = os.getgid()
@ -34,36 +33,36 @@ class StandardNotesFUSE(LoggingMixIn, Operations):
self.sync_sec = sync_sec
self.run_sync = Event()
self.stop_sync = Event()
self.sync_thread = Thread(target=self._syncThread)
self.sync_thread = Thread(target=self._sync_thread)
def init(self, path):
self.sync_thread.start()
def destroy(self, path):
self._syncNow()
self._sync_now()
logging.info('Stopping sync thread.')
self.stop_sync.set()
self.sync_thread.join()
return 0
def _syncThread(self):
def _sync_thread(self):
while not self.stop_sync.is_set():
self.run_sync.clear()
manually_synced = self.run_sync.wait(timeout=self.sync_sec)
if not manually_synced: logging.info('Auto-syncing items...')
time.sleep(0.1) # fixes race condition of quick create() then write()
sleep(0.1) # fixes race condition of quick create() then write()
try:
self.item_manager.syncItems()
self.item_manager.sync_items()
except ConnectionError:
logging.error('Unable to connect to sync server. Retrying...')
def _syncNow(self):
def _sync_now(self):
self.run_sync.set()
def _pathToNote(self, path):
def _path_to_note(self, path):
pp = PurePath(path)
note_name = pp.parts[1]
self.notes = self.item_manager.getNotes()
self.notes = self.item_manager.get_notes()
note = self.notes[note_name]
return note, note['uuid']
@ -72,7 +71,7 @@ class StandardNotesFUSE(LoggingMixIn, Operations):
return self.dir_stat
try:
note, uuid = self._pathToNote(path)
note, uuid = self._path_to_note(path)
st = self.note_stat
st['st_size'] = len(note['text'])
st['st_ctime'] = iso8601.parse_date(note['created']).timestamp()
@ -89,27 +88,27 @@ class StandardNotesFUSE(LoggingMixIn, Operations):
return dirents
def read(self, path, size, offset, fh):
note, uuid = self._pathToNote(path)
note, uuid = self._path_to_note(path)
return note['text'][offset : offset + size]
def truncate(self, path, length, fh=None):
note, uuid = self._pathToNote(path)
note, uuid = self._path_to_note(path)
text = note['text'][:length]
self.item_manager.writeNote(uuid, text)
self._syncNow()
self.item_manager.write_note(uuid, text)
self._sync_now()
return 0
def write(self, path, data, offset, fh):
note, uuid = self._pathToNote(path)
note, uuid = self._path_to_note(path)
text = note['text'][:offset] + data
try:
self.item_manager.writeNote(uuid, text)
self.item_manager.write_note(uuid, text)
except UnicodeError:
logging.error('Unable to parse non-unicode data.')
raise FuseOSError(errno.EIO)
self._syncNow()
self._sync_now()
return len(data)
def create(self, path, mode):
@ -123,14 +122,14 @@ class StandardNotesFUSE(LoggingMixIn, Operations):
now = datetime.utcnow().isoformat()[:-3] + 'Z' # hack
self.item_manager.createNote(note_name, now)
self._syncNow()
self.item_manager.create_note(note_name, now)
self._sync_now()
return 0
def unlink(self, path):
note, uuid = self._pathToNote(path)
self.item_manager.deleteNote(uuid)
self._syncNow()
note, uuid = self._path_to_note(path)
self.item_manager.delete_note(uuid)
self._sync_now()
return 0
def mkdir(self, path, mode):
@ -138,17 +137,17 @@ class StandardNotesFUSE(LoggingMixIn, Operations):
raise FuseOSError(errno.EPERM)
def utimens(self, path, times=None):
note, uuid = self._pathToNote(path)
self.item_manager.touchNote(uuid)
self._syncNow()
note, uuid = self._path_to_note(path)
self.item_manager.touch_note(uuid)
self._sync_now()
return 0
def rename(self, old, new):
note, uuid = self._pathToNote(old)
note, uuid = self._path_to_note(old)
new_path_parts = new.split('/')
new_note_name = new_path_parts[1]
self.item_manager.renameNote(uuid, new_note_name)
self._syncNow()
self.item_manager.rename_note(uuid, new_note_name)
self._sync_now()
return 0
def chmod(self, path, mode):

View File

@ -1,18 +1,18 @@
import appdirs
import argparse
from configparser import ConfigParser
from getpass import getpass
import logging
import os
import pathlib
import sys
from configparser import ConfigParser
from getpass import getpass
from api import StandardNotesAPI, SNAPIException
from sn_fuse import StandardNotesFUSE
import appdirs
from fuse import FUSE
from requests.exceptions import ConnectionError, MissingSchema
from api import SNAPIException, StandardNotesAPI
from sn_fuse import StandardNotesFUSE
OFFICIAL_SERVER_URL = 'https://sync.standardnotes.org'
DEFAULT_SYNC_SEC = 30
MINIMUM_SYNC_SEC = 5
@ -38,7 +38,8 @@ def parse_options():
parser.add_argument('--foreground', action='store_true',
help='run standardnotes-fs in the foreground')
parser.add_argument('--sync-sec', type=int, default=DEFAULT_SYNC_SEC,
help='how many seconds between each sync. Default: 10')
help='how many seconds between each sync. Default: '
''+str(DEFAULT_SYNC_SEC))
parser.add_argument('--sync-url',
help='URL of Standard File sync server. Defaults to:\n'
''+OFFICIAL_SERVER_URL)
@ -123,25 +124,25 @@ def main():
log_msg = 'Using sync URL "%s".'
logging.info(log_msg % sync_url)
if config.has_option('user', 'username') \
and config.has_section('keys') \
and not args.username \
and not args.password:
if (config.has_option('user', 'username')
and config.has_section('keys')
and not args.username
and not args.password):
username = config.get('user', 'username')
keys = dict(config.items('keys'))
else:
username = args.username if args.username else \
input('Please enter your Standard Notes username: ')
password = args.password if args.password else \
getpass('Please enter your password (hidden): ')
username = (args.username if args.username else
input('Please enter your Standard Notes username: '))
password = (args.password if args.password else
getpass('Please enter your password (hidden): '))
# log the user in
try:
sn_api = StandardNotesAPI(sync_url, username)
if not keys:
keys = sn_api.genKeys(password)
keys = sn_api.gen_keys(password)
del password
sn_api.signIn(keys)
sn_api.sign_in(keys)
log_msg = 'Successfully logged into account "%s".'
logging.info(log_msg % username)
login_success = True
@ -164,7 +165,6 @@ def main():
keys=keys))
config.write(f)
log_msg = 'Config written to file "%s".'
logging.info(log_msg % str(config_file))
else:
log_msg = 'Clearing config file "%s".'
logging.info(log_msg % config_file)