Package app: Wrap oauth2.TokenSource to ensure datastore is always updated.
[kraftakt.git] / app / user.go
1 package app
2
3 import (
4         "context"
5         "fmt"
6         "net/http"
7
8         "github.com/google/uuid"
9         legacy_context "golang.org/x/net/context"
10         "golang.org/x/oauth2"
11         "google.golang.org/appengine/datastore"
12         "google.golang.org/appengine/log"
13 )
14
15 type User struct {
16         Email string
17         key   *datastore.Key
18 }
19
20 type dbUser struct {
21         ID string
22 }
23
24 func NewUser(ctx context.Context, email string) (*User, error) {
25         err := datastore.RunInTransaction(ctx, func(ctx legacy_context.Context) error {
26                 key := datastore.NewKey(ctx, "User", email, 0, nil)
27                 if err := datastore.Get(ctx, key, &dbUser{}); err != datastore.ErrNoSuchEntity {
28                         return err // may be nil
29                 }
30
31                 _, err := datastore.Put(ctx, key, &dbUser{
32                         ID: uuid.New().String(),
33                 })
34                 return err
35         }, nil)
36         if err != nil {
37                 return nil, err
38         }
39
40         return &User{
41                 Email: email,
42                 key:   datastore.NewKey(ctx, "User", email, 0, nil),
43         }, nil
44 }
45
46 func UserByID(ctx context.Context, id string) (*User, error) {
47         q := datastore.NewQuery("User").Filter("ID=", id).KeysOnly()
48         keys, err := q.GetAll(ctx, nil)
49         if err != nil {
50                 return nil, fmt.Errorf("datastore.Query.GetAll(): %v", err)
51         }
52         if len(keys) != 1 {
53                 return nil, fmt.Errorf("len(keys) = %d, want 1", len(keys))
54         }
55
56         return &User{
57                 Email: keys[0].StringID(),
58                 key:   keys[0],
59         }, nil
60 }
61
62 func (u *User) ID(ctx context.Context) (string, error) {
63         var db dbUser
64         if err := datastore.Get(ctx, u.key, &db); err != nil {
65                 return "", err
66         }
67
68         return db.ID, nil
69 }
70
71 func (u *User) Token(ctx context.Context, svc string) (*oauth2.Token, error) {
72         key := datastore.NewKey(ctx, "Token", svc, 0, u.key)
73
74         var tok oauth2.Token
75         if err := datastore.Get(ctx, key, &tok); err != nil {
76                 return nil, err
77         }
78
79         return &tok, nil
80 }
81
82 func (u *User) SetToken(ctx context.Context, svc string, tok *oauth2.Token) error {
83         key := datastore.NewKey(ctx, "Token", svc, 0, u.key)
84         _, err := datastore.Put(ctx, key, tok)
85         return err
86 }
87
88 func (u *User) OAuthClient(ctx context.Context, svc string, cfg *oauth2.Config) (*http.Client, error) {
89         key := datastore.NewKey(ctx, "Token", svc, 0, u.key)
90
91         var tok oauth2.Token
92         if err := datastore.Get(ctx, key, &tok); err != nil {
93                 return nil, err
94         }
95
96         src := cfg.TokenSource(ctx, &tok)
97         return oauth2.NewClient(ctx, &persistingTokenSource{
98                 ctx: ctx,
99                 t:   &tok,
100                 src: src,
101                 key: key,
102         }), nil
103 }
104
105 type persistingTokenSource struct {
106         ctx context.Context
107         t   *oauth2.Token
108         src oauth2.TokenSource
109         key *datastore.Key
110 }
111
112 func (s *persistingTokenSource) Token() (*oauth2.Token, error) {
113         tok, err := s.src.Token()
114         if err != nil {
115                 return nil, err
116         }
117
118         if s.t.RefreshToken != tok.RefreshToken {
119                 if _, err := datastore.Put(s.ctx, s.key, tok); err != nil {
120                         log.Errorf(s.ctx, "persisting OAuth token in datastore failed: %v", err)
121                 }
122         }
123
124         s.t = tok
125         return tok, nil
126 }