diff --git a/api.py b/api.py index 13c4023..3c5cb7f 100644 --- a/api.py +++ b/api.py @@ -27,11 +27,15 @@ from cosmopolite.lib import utils import config +class InvalidInstanceID(Exception): + pass + + def CreateChannel(google_user, client, instance_id, args): - instance = models.Instance.FindOrCreate(instance_id, client) + models.Instance.FindOrCreate(instance_id, client) token = channel.create_channel( - client_id=str(instance.id()), + client_id=instance_id, duration_minutes=config.CHANNEL_DURATION_SECONDS / 60) events = [] if google_user: @@ -78,7 +82,10 @@ def SendMessage(google_user, client, instance_id, args): def Subscribe(google_user, client, instance_id, args): - instance = models.Instance.FromID(instance_id, client) + instance = models.Instance.FromID(instance_id) + if not instance or not instance.active: + raise InvalidInstanceID + subject = models.Subject.FindOrCreate(args['subject']) messages = args.get('messages', 0) last_id = args.get('last_id', None) @@ -105,7 +112,7 @@ def Subscribe(google_user, client, instance_id, args): def Unsubscribe(google_user, client, instance_id, args): - instance = models.Instance.FromID(instance_id, client) + instance = models.Instance.FromID(instance_id) subject = models.Subject.FindOrCreate(args['subject']) models.Subscription.Remove(subject, instance) diff --git a/channel.py b/channel.py index 7385e01..7d84690 100644 --- a/channel.py +++ b/channel.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import webapp2 from google.appengine.ext import db @@ -26,7 +27,7 @@ class OnChannelConnect(webapp2.RequestHandler): @db.transactional() def post(self): instance_id = self.request.get('from') - instance = models.Instance.get_by_id(instance_id) + instance = models.Instance.FromID(instance_id) instance.active = True instance.put() @@ -35,7 +36,7 @@ class OnChannelDisconnect(webapp2.RequestHandler): @utils.local_namespace def post(self): instance_id = self.request.get('from') - instance = models.Instance.get_by_id(instance_id) + instance = models.Instance.FromID(instance_id) subscriptions = models.Subscription.all().filter('instance =', instance) for subscription in subscriptions: diff --git a/lib/models.py b/lib/models.py index e02c959..32b221f 100644 --- a/lib/models.py +++ b/lib/models.py @@ -24,7 +24,8 @@ import utils # Profile # ↳ Client -# ↳ Instance +# +# Instance (⤴︎ Client) # # Subject # ↳ Message @@ -86,32 +87,24 @@ class Client(db.Model): class Instance(db.Model): - # parent=Client - - id_ = db.StringProperty(required=True) + client = db.ReferenceProperty(required=True) active = db.BooleanProperty(required=True, default=False) @classmethod @db.transactional() - def FromID(cls, instance_id, client): - instances = ( - cls.all(keys_only=True) - .filter('id_ =', instance_id) - .ancestor(client) - .fetch(1)) - if instances: - return instances[0] - else: - return None + def FromID(cls, instance_id): + # TODO: assert client equality here if possible + return cls.get_by_key_name(instance_id) @classmethod @db.transactional() def FindOrCreate(cls, instance_id, client): - instance = cls.FromID(instance_id, client) + instance = cls.get_by_key_name(instance_id) if instance: + # TODO: assert client equality here return instance else: - return cls(parent=client, id_=instance_id).put() + return cls(key_name=instance_id, client=client).put() class Subject(db.Model): @@ -252,8 +245,10 @@ class Subscription(db.Model): def FindOrCreate(cls, subject, instance, messages=0, last_id=None): readable_only_by = ( Subject.readable_only_by.get_value_for_datastore(subject)) + client_key = ( + Instance.client.get_value_for_datastore(instance)) if (readable_only_by and - readable_only_by != instance.parent().parent()): + readable_only_by != client_key.parent()): raise AccessDenied subscriptions = ( @@ -283,7 +278,7 @@ class Subscription(db.Model): def SendMessage(self, msg): instance_key = Subscription.instance.get_value_for_datastore(self) channel.send_message( - str(instance_key.id()), + str(instance_key.name()), json.dumps(msg, default=utils.EncodeJSON))