Remove x/net/context vendor by using std package (#5202)

* Update dep github.com/markbates/goth

* Update dep github.com/blevesearch/bleve

* Update dep golang.org/x/oauth2

* Fix github.com/blevesearch/bleve to c74e08f039e56cef576e4336382b2a2d12d9e026

* Update dep golang.org/x/oauth2
This commit is contained in:
Antoine GIRARD 2018-11-11 00:55:36 +01:00 committed by techknowlogick
parent b3000ae623
commit 4c1f1f9646
40 changed files with 1492 additions and 644 deletions

18
Gopkg.lock generated
View File

@ -90,7 +90,7 @@
revision = "3a771d992973f24aa725d07868b467d1ddfceafb" revision = "3a771d992973f24aa725d07868b467d1ddfceafb"
[[projects]] [[projects]]
digest = "1:67351095005f164e748a5a21899d1403b03878cb2d40a7b0f742376e6eeda974" digest = "1:c10f35be6200b09e26da267ca80f837315093ecaba27e7a223071380efb9dd32"
name = "github.com/blevesearch/bleve" name = "github.com/blevesearch/bleve"
packages = [ packages = [
".", ".",
@ -135,7 +135,7 @@
"search/searcher", "search/searcher",
] ]
pruneopts = "NUT" pruneopts = "NUT"
revision = "ff210fbc6d348ad67aa5754eaea11a463fcddafd" revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -557,7 +557,7 @@
revision = "e3534c89ef969912856dfa39e56b09e58c5f5daf" revision = "e3534c89ef969912856dfa39e56b09e58c5f5daf"
[[projects]] [[projects]]
digest = "1:23f75ae90fcc38dac6fad6881006ea7d0f2c78db5f9f81f3df558dc91460e61f" digest = "1:4b992ec853d0ea9bac3dcf09a64af61de1a392e6cb0eef2204c0c92f4ae6b911"
name = "github.com/markbates/goth" name = "github.com/markbates/goth"
packages = [ packages = [
".", ".",
@ -572,8 +572,8 @@
"providers/twitter", "providers/twitter",
] ]
pruneopts = "NUT" pruneopts = "NUT"
revision = "f9c6649ab984d6ea71ef1e13b7b1cdffcf4592d3" revision = "bc6d8ddf751a745f37ca5567dbbfc4157bbf5da9"
version = "v1.46.1" version = "v1.47.2"
[[projects]] [[projects]]
digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5" digest = "1:c9724c929d27a14475a45b17a267dbc60671c0bc2c5c05ed21f011f7b5bc9fb5"
@ -809,10 +809,11 @@
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:6d5ed712653ea5321fe3e3475ab2188cf362a4e0d31e9fd3acbd4dfbbca0d680" digest = "1:d0a0bdd2b64d981aa4e6a1ade90431d042cd7fa31b584e33d45e62cbfec43380"
name = "golang.org/x/net" name = "golang.org/x/net"
packages = [ packages = [
"context", "context",
"context/ctxhttp",
"html", "html",
"html/atom", "html/atom",
"html/charset", "html/charset",
@ -821,14 +822,15 @@
revision = "9b4f9f5ad5197c79fd623a3638e70d8b26cef344" revision = "9b4f9f5ad5197c79fd623a3638e70d8b26cef344"
[[projects]] [[projects]]
digest = "1:8159a9cda4b8810aaaeb0d60e2fa68e2fd86d8af4ec8f5059830839e3c8d93d5" branch = "master"
digest = "1:274a6321a5a9f185eeb3fab5d7d8397e0e9f57737490d749f562c7e205ffbc2e"
name = "golang.org/x/oauth2" name = "golang.org/x/oauth2"
packages = [ packages = [
".", ".",
"internal", "internal",
] ]
pruneopts = "NUT" pruneopts = "NUT"
revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061" revision = "c453e0c757598fd055e170a3a359263c91e13153"
[[projects]] [[projects]]
digest = "1:9f303486d623f840492bfeb48eb906a94e9d3fe638a761639b72ce64bf7bfcc3" digest = "1:9f303486d623f840492bfeb48eb906a94e9d3fe638a761639b72ce64bf7bfcc3"

View File

@ -14,6 +14,12 @@ ignored = ["google.golang.org/appengine*"]
branch = "master" branch = "master"
name = "code.gitea.io/sdk" name = "code.gitea.io/sdk"
[[constraint]]
# branch = "master"
revision = "c74e08f039e56cef576e4336382b2a2d12d9e026"
name = "github.com/blevesearch/bleve"
#Not targetting v0.7.0 since standard where use only just after this tag
[[constraint]] [[constraint]]
revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e" revision = "12dd70caea0268ac0d6c2707d0611ef601e7c64e"
name = "golang.org/x/crypto" name = "golang.org/x/crypto"
@ -61,7 +67,7 @@ ignored = ["google.golang.org/appengine*"]
[[constraint]] [[constraint]]
name = "github.com/markbates/goth" name = "github.com/markbates/goth"
version = "1.46.1" version = "1.47.2"
[[constraint]] [[constraint]]
branch = "master" branch = "master"
@ -105,7 +111,7 @@ ignored = ["google.golang.org/appengine*"]
source = "github.com/go-gitea/bolt" source = "github.com/go-gitea/bolt"
[[override]] [[override]]
revision = "c10ba270aa0bf8b8c1c986e103859c67a9103061" branch = "master"
name = "golang.org/x/oauth2" name = "golang.org/x/oauth2"
[[constraint]] [[constraint]]

View File

@ -15,11 +15,12 @@
package bleve package bleve
import ( import (
"context"
"github.com/blevesearch/bleve/document" "github.com/blevesearch/bleve/document"
"github.com/blevesearch/bleve/index" "github.com/blevesearch/bleve/index"
"github.com/blevesearch/bleve/index/store" "github.com/blevesearch/bleve/index/store"
"github.com/blevesearch/bleve/mapping" "github.com/blevesearch/bleve/mapping"
"golang.org/x/net/context"
) )
// A Batch groups together multiple Index and Delete // A Batch groups together multiple Index and Delete

View File

@ -100,8 +100,8 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
// prepare new index snapshot // prepare new index snapshot
newSnapshot := &IndexSnapshot{ newSnapshot := &IndexSnapshot{
parent: s, parent: s,
segment: make([]*SegmentSnapshot, nsegs, nsegs+1), segment: make([]*SegmentSnapshot, 0, nsegs+1),
offsets: make([]uint64, nsegs, nsegs+1), offsets: make([]uint64, 0, nsegs+1),
internal: make(map[string][]byte, len(s.root.internal)), internal: make(map[string][]byte, len(s.root.internal)),
epoch: s.nextSnapshotEpoch, epoch: s.nextSnapshotEpoch,
refs: 1, refs: 1,
@ -124,24 +124,29 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
return err return err
} }
} }
newSnapshot.segment[i] = &SegmentSnapshot{
newss := &SegmentSnapshot{
id: s.root.segment[i].id, id: s.root.segment[i].id,
segment: s.root.segment[i].segment, segment: s.root.segment[i].segment,
cachedDocs: s.root.segment[i].cachedDocs, cachedDocs: s.root.segment[i].cachedDocs,
} }
s.root.segment[i].segment.AddRef()
// apply new obsoletions // apply new obsoletions
if s.root.segment[i].deleted == nil { if s.root.segment[i].deleted == nil {
newSnapshot.segment[i].deleted = delta newss.deleted = delta
} else { } else {
newSnapshot.segment[i].deleted = roaring.Or(s.root.segment[i].deleted, delta) newss.deleted = roaring.Or(s.root.segment[i].deleted, delta)
} }
newSnapshot.offsets[i] = running // check for live size before copying
if newss.LiveSize() > 0 {
newSnapshot.segment = append(newSnapshot.segment, newss)
s.root.segment[i].segment.AddRef()
newSnapshot.offsets = append(newSnapshot.offsets, running)
running += s.root.segment[i].Count() running += s.root.segment[i].Count()
} }
}
// append new segment, if any, to end of the new index snapshot // append new segment, if any, to end of the new index snapshot
if next.data != nil { if next.data != nil {
newSegmentSnapshot := &SegmentSnapshot{ newSegmentSnapshot := &SegmentSnapshot{
@ -193,6 +198,12 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
// prepare new index snapshot // prepare new index snapshot
currSize := len(s.root.segment) currSize := len(s.root.segment)
newSize := currSize + 1 - len(nextMerge.old) newSize := currSize + 1 - len(nextMerge.old)
// empty segments deletion
if nextMerge.new == nil {
newSize--
}
newSnapshot := &IndexSnapshot{ newSnapshot := &IndexSnapshot{
parent: s, parent: s,
segment: make([]*SegmentSnapshot, 0, newSize), segment: make([]*SegmentSnapshot, 0, newSize),
@ -210,7 +221,7 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
segmentID := s.root.segment[i].id segmentID := s.root.segment[i].id
if segSnapAtMerge, ok := nextMerge.old[segmentID]; ok { if segSnapAtMerge, ok := nextMerge.old[segmentID]; ok {
// this segment is going away, see if anything else was deleted since we started the merge // this segment is going away, see if anything else was deleted since we started the merge
if s.root.segment[i].deleted != nil { if segSnapAtMerge != nil && s.root.segment[i].deleted != nil {
// assume all these deletes are new // assume all these deletes are new
deletedSince := s.root.segment[i].deleted deletedSince := s.root.segment[i].deleted
// if we already knew about some of them, remove // if we already knew about some of them, remove
@ -224,7 +235,13 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
newSegmentDeleted.Add(uint32(newDocNum)) newSegmentDeleted.Add(uint32(newDocNum))
} }
} }
} else { // clean up the old segment map to figure out the
// obsolete segments wrt root in meantime, whatever
// segments left behind in old map after processing
// the root segments would be the obsolete segment set
delete(nextMerge.old, segmentID)
} else if s.root.segment[i].LiveSize() > 0 {
// this segment is staying // this segment is staying
newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{ newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
id: s.root.segment[i].id, id: s.root.segment[i].id,
@ -238,6 +255,24 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
} }
} }
// before the newMerge introduction, need to clean the newly
// merged segment wrt the current root segments, hence
// applying the obsolete segment contents to newly merged segment
for segID, ss := range nextMerge.old {
obsoleted := ss.DocNumbersLive()
if obsoleted != nil {
obsoletedIter := obsoleted.Iterator()
for obsoletedIter.HasNext() {
oldDocNum := obsoletedIter.Next()
newDocNum := nextMerge.oldNewDocNums[segID][oldDocNum]
newSegmentDeleted.Add(uint32(newDocNum))
}
}
}
// In case where all the docs in the newly merged segment getting
// deleted by the time we reach here, can skip the introduction.
if nextMerge.new != nil &&
nextMerge.new.Count() > newSegmentDeleted.GetCardinality() {
// put new segment at end // put new segment at end
newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{ newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{
id: nextMerge.id, id: nextMerge.id,
@ -246,6 +281,9 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
cachedDocs: &cachedDocs{cache: nil}, cachedDocs: &cachedDocs{cache: nil},
}) })
newSnapshot.offsets = append(newSnapshot.offsets, running) newSnapshot.offsets = append(newSnapshot.offsets, running)
}
newSnapshot.AddRef() // 1 ref for the nextMerge.notify response
// swap in new segment // swap in new segment
rootPrev := s.root rootPrev := s.root
@ -257,7 +295,8 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
_ = rootPrev.DecRef() _ = rootPrev.DecRef()
} }
// notify merger we incorporated this // notify requester that we incorporated this
nextMerge.notify <- newSnapshot
close(nextMerge.notify) close(nextMerge.notify)
} }

View File

@ -15,6 +15,9 @@
package scorch package scorch
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"os" "os"
"sync/atomic" "sync/atomic"
@ -28,6 +31,13 @@ import (
func (s *Scorch) mergerLoop() { func (s *Scorch) mergerLoop() {
var lastEpochMergePlanned uint64 var lastEpochMergePlanned uint64
mergePlannerOptions, err := s.parseMergePlannerOptions()
if err != nil {
s.fireAsyncError(fmt.Errorf("mergePlannerOption json parsing err: %v", err))
s.asyncTasks.Done()
return
}
OUTER: OUTER:
for { for {
select { select {
@ -45,7 +55,7 @@ OUTER:
startTime := time.Now() startTime := time.Now()
// lets get started // lets get started
err := s.planMergeAtSnapshot(ourSnapshot) err := s.planMergeAtSnapshot(ourSnapshot, mergePlannerOptions)
if err != nil { if err != nil {
s.fireAsyncError(fmt.Errorf("merging err: %v", err)) s.fireAsyncError(fmt.Errorf("merging err: %v", err))
_ = ourSnapshot.DecRef() _ = ourSnapshot.DecRef()
@ -58,51 +68,49 @@ OUTER:
_ = ourSnapshot.DecRef() _ = ourSnapshot.DecRef()
// tell the persister we're waiting for changes // tell the persister we're waiting for changes
// first make a notification chan // first make a epochWatcher chan
notifyUs := make(notificationChan) ew := &epochWatcher{
epoch: lastEpochMergePlanned,
notifyCh: make(notificationChan, 1),
}
// give it to the persister // give it to the persister
select { select {
case <-s.closeCh: case <-s.closeCh:
break OUTER break OUTER
case s.persisterNotifier <- notifyUs: case s.persisterNotifier <- ew:
} }
// check again // now wait for persister (but also detect close)
s.rootLock.RLock()
ourSnapshot = s.root
ourSnapshot.AddRef()
s.rootLock.RUnlock()
if ourSnapshot.epoch != lastEpochMergePlanned {
startTime := time.Now()
// lets get started
err := s.planMergeAtSnapshot(ourSnapshot)
if err != nil {
s.fireAsyncError(fmt.Errorf("merging err: %v", err))
_ = ourSnapshot.DecRef()
continue OUTER
}
lastEpochMergePlanned = ourSnapshot.epoch
s.fireEvent(EventKindMergerProgress, time.Since(startTime))
}
_ = ourSnapshot.DecRef()
// now wait for it (but also detect close)
select { select {
case <-s.closeCh: case <-s.closeCh:
break OUTER break OUTER
case <-notifyUs: case <-ew.notifyCh:
// woken up, next loop should pick up work
} }
} }
} }
s.asyncTasks.Done() s.asyncTasks.Done()
} }
func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error { func (s *Scorch) parseMergePlannerOptions() (*mergeplan.MergePlanOptions,
error) {
mergePlannerOptions := mergeplan.DefaultMergePlanOptions
if v, ok := s.config["scorchMergePlanOptions"]; ok {
b, err := json.Marshal(v)
if err != nil {
return &mergePlannerOptions, err
}
err = json.Unmarshal(b, &mergePlannerOptions)
if err != nil {
return &mergePlannerOptions, err
}
}
return &mergePlannerOptions, nil
}
func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot,
options *mergeplan.MergePlanOptions) error {
// build list of zap segments in this snapshot // build list of zap segments in this snapshot
var onlyZapSnapshots []mergeplan.Segment var onlyZapSnapshots []mergeplan.Segment
for _, segmentSnapshot := range ourSnapshot.segment { for _, segmentSnapshot := range ourSnapshot.segment {
@ -112,7 +120,7 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
} }
// give this list to the planner // give this list to the planner
resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, nil) resultMergePlan, err := mergeplan.Plan(onlyZapSnapshots, options)
if err != nil { if err != nil {
return fmt.Errorf("merge planning err: %v", err) return fmt.Errorf("merge planning err: %v", err)
} }
@ -122,8 +130,12 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
} }
// process tasks in serial for now // process tasks in serial for now
var notifications []notificationChan var notifications []chan *IndexSnapshot
for _, task := range resultMergePlan.Tasks { for _, task := range resultMergePlan.Tasks {
if len(task.Segments) == 0 {
continue
}
oldMap := make(map[uint64]*SegmentSnapshot) oldMap := make(map[uint64]*SegmentSnapshot)
newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1) newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
segmentsToMerge := make([]*zap.Segment, 0, len(task.Segments)) segmentsToMerge := make([]*zap.Segment, 0, len(task.Segments))
@ -132,40 +144,51 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
if segSnapshot, ok := planSegment.(*SegmentSnapshot); ok { if segSnapshot, ok := planSegment.(*SegmentSnapshot); ok {
oldMap[segSnapshot.id] = segSnapshot oldMap[segSnapshot.id] = segSnapshot
if zapSeg, ok := segSnapshot.segment.(*zap.Segment); ok { if zapSeg, ok := segSnapshot.segment.(*zap.Segment); ok {
if segSnapshot.LiveSize() == 0 {
oldMap[segSnapshot.id] = nil
} else {
segmentsToMerge = append(segmentsToMerge, zapSeg) segmentsToMerge = append(segmentsToMerge, zapSeg)
docsToDrop = append(docsToDrop, segSnapshot.deleted) docsToDrop = append(docsToDrop, segSnapshot.deleted)
} }
} }
} }
}
var oldNewDocNums map[uint64][]uint64
var segment segment.Segment
if len(segmentsToMerge) > 0 {
filename := zapFileName(newSegmentID) filename := zapFileName(newSegmentID)
s.markIneligibleForRemoval(filename) s.markIneligibleForRemoval(filename)
path := s.path + string(os.PathSeparator) + filename path := s.path + string(os.PathSeparator) + filename
newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, DefaultChunkFactor) newDocNums, err := zap.Merge(segmentsToMerge, docsToDrop, path, 1024)
if err != nil { if err != nil {
s.unmarkIneligibleForRemoval(filename) s.unmarkIneligibleForRemoval(filename)
return fmt.Errorf("merging failed: %v", err) return fmt.Errorf("merging failed: %v", err)
} }
segment, err := zap.Open(path) segment, err = zap.Open(path)
if err != nil { if err != nil {
s.unmarkIneligibleForRemoval(filename) s.unmarkIneligibleForRemoval(filename)
return err return err
} }
oldNewDocNums = make(map[uint64][]uint64)
for i, segNewDocNums := range newDocNums {
oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
}
}
sm := &segmentMerge{ sm := &segmentMerge{
id: newSegmentID, id: newSegmentID,
old: oldMap, old: oldMap,
oldNewDocNums: make(map[uint64][]uint64), oldNewDocNums: oldNewDocNums,
new: segment, new: segment,
notify: make(notificationChan), notify: make(chan *IndexSnapshot, 1),
} }
notifications = append(notifications, sm.notify) notifications = append(notifications, sm.notify)
for i, segNewDocNums := range newDocNums {
sm.oldNewDocNums[task.Segments[i].Id()] = segNewDocNums
}
// give it to the introducer // give it to the introducer
select { select {
case <-s.closeCh: case <-s.closeCh:
_ = segment.Close()
return nil return nil
case s.merges <- sm: case s.merges <- sm:
} }
@ -174,7 +197,10 @@ func (s *Scorch) planMergeAtSnapshot(ourSnapshot *IndexSnapshot) error {
select { select {
case <-s.closeCh: case <-s.closeCh:
return nil return nil
case <-notification: case newSnapshot := <-notification:
if newSnapshot != nil {
_ = newSnapshot.DecRef()
}
} }
} }
return nil return nil
@ -185,5 +211,72 @@ type segmentMerge struct {
old map[uint64]*SegmentSnapshot old map[uint64]*SegmentSnapshot
oldNewDocNums map[uint64][]uint64 oldNewDocNums map[uint64][]uint64
new segment.Segment new segment.Segment
notify notificationChan notify chan *IndexSnapshot
}
// perform a merging of the given SegmentBase instances into a new,
// persisted segment, and synchronously introduce that new segment
// into the root
func (s *Scorch) mergeSegmentBases(snapshot *IndexSnapshot,
sbs []*zap.SegmentBase, sbsDrops []*roaring.Bitmap, sbsIndexes []int,
chunkFactor uint32) (uint64, *IndexSnapshot, uint64, error) {
var br bytes.Buffer
cr := zap.NewCountHashWriter(&br)
newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset,
docValueOffset, dictLocs, fieldsInv, fieldsMap, err :=
zap.MergeToWriter(sbs, sbsDrops, chunkFactor, cr)
if err != nil {
return 0, nil, 0, err
}
sb, err := zap.InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
fieldsMap, fieldsInv, numDocs, storedIndexOffset, fieldsIndexOffset,
docValueOffset, dictLocs)
if err != nil {
return 0, nil, 0, err
}
newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1)
filename := zapFileName(newSegmentID)
path := s.path + string(os.PathSeparator) + filename
err = zap.PersistSegmentBase(sb, path)
if err != nil {
return 0, nil, 0, err
}
segment, err := zap.Open(path)
if err != nil {
return 0, nil, 0, err
}
sm := &segmentMerge{
id: newSegmentID,
old: make(map[uint64]*SegmentSnapshot),
oldNewDocNums: make(map[uint64][]uint64),
new: segment,
notify: make(chan *IndexSnapshot, 1),
}
for i, idx := range sbsIndexes {
ss := snapshot.segment[idx]
sm.old[ss.id] = ss
sm.oldNewDocNums[ss.id] = newDocNums[i]
}
select { // send to introducer
case <-s.closeCh:
_ = segment.DecRef()
return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
case s.merges <- sm:
}
select { // wait for introduction to complete
case <-s.closeCh:
return 0, nil, 0, nil // TODO: return ErrInterruptedClosed?
case newSnapshot := <-sm.notify:
return numDocs, newSnapshot, newSegmentID, nil
}
} }

View File

@ -186,13 +186,13 @@ func plan(segmentsIn []Segment, o *MergePlanOptions) (*MergePlan, error) {
// While were over budget, keep looping, which might produce // While were over budget, keep looping, which might produce
// another MergeTask. // another MergeTask.
for len(eligibles) > budgetNumSegments { for len(eligibles) > 0 && (len(eligibles)+len(rv.Tasks)) > budgetNumSegments {
// Track a current best roster as we examine and score // Track a current best roster as we examine and score
// potential rosters of merges. // potential rosters of merges.
var bestRoster []Segment var bestRoster []Segment
var bestRosterScore float64 // Lower score is better. var bestRosterScore float64 // Lower score is better.
for startIdx := 0; startIdx < len(eligibles)-o.SegmentsPerMergeTask; startIdx++ { for startIdx := 0; startIdx < len(eligibles); startIdx++ {
var roster []Segment var roster []Segment
var rosterLiveSize int64 var rosterLiveSize int64

View File

@ -34,22 +34,39 @@ import (
var DefaultChunkFactor uint32 = 1024 var DefaultChunkFactor uint32 = 1024
// Arbitrary number, need to make it configurable.
// Lower values like 10/making persister really slow
// doesn't work well as it is creating more files to
// persist for in next persist iteration and spikes the # FDs.
// Ideal value should let persister also proceed at
// an optimum pace so that the merger can skip
// many intermediate snapshots.
// This needs to be based on empirical data.
// TODO - may need to revisit this approach/value.
var epochDistance = uint64(5)
type notificationChan chan struct{} type notificationChan chan struct{}
func (s *Scorch) persisterLoop() { func (s *Scorch) persisterLoop() {
defer s.asyncTasks.Done() defer s.asyncTasks.Done()
var notifyChs []notificationChan var persistWatchers []*epochWatcher
var lastPersistedEpoch uint64 var lastPersistedEpoch, lastMergedEpoch uint64
var ew *epochWatcher
OUTER: OUTER:
for { for {
select { select {
case <-s.closeCh: case <-s.closeCh:
break OUTER break OUTER
case notifyCh := <-s.persisterNotifier: case ew = <-s.persisterNotifier:
notifyChs = append(notifyChs, notifyCh) persistWatchers = append(persistWatchers, ew)
default: default:
} }
if ew != nil && ew.epoch > lastMergedEpoch {
lastMergedEpoch = ew.epoch
}
persistWatchers = s.pausePersisterForMergerCatchUp(lastPersistedEpoch,
&lastMergedEpoch, persistWatchers)
var ourSnapshot *IndexSnapshot var ourSnapshot *IndexSnapshot
var ourPersisted []chan error var ourPersisted []chan error
@ -81,10 +98,11 @@ OUTER:
} }
lastPersistedEpoch = ourSnapshot.epoch lastPersistedEpoch = ourSnapshot.epoch
for _, notifyCh := range notifyChs { for _, ew := range persistWatchers {
close(notifyCh) close(ew.notifyCh)
} }
notifyChs = nil
persistWatchers = nil
_ = ourSnapshot.DecRef() _ = ourSnapshot.DecRef()
changed := false changed := false
@ -120,27 +138,155 @@ OUTER:
break OUTER break OUTER
case <-w.notifyCh: case <-w.notifyCh:
// woken up, next loop should pick up work // woken up, next loop should pick up work
continue OUTER
case ew = <-s.persisterNotifier:
// if the watchers are already caught up then let them wait,
// else let them continue to do the catch up
persistWatchers = append(persistWatchers, ew)
} }
} }
} }
func notifyMergeWatchers(lastPersistedEpoch uint64,
persistWatchers []*epochWatcher) []*epochWatcher {
var watchersNext []*epochWatcher
for _, w := range persistWatchers {
if w.epoch < lastPersistedEpoch {
close(w.notifyCh)
} else {
watchersNext = append(watchersNext, w)
}
}
return watchersNext
}
func (s *Scorch) pausePersisterForMergerCatchUp(lastPersistedEpoch uint64, lastMergedEpoch *uint64,
persistWatchers []*epochWatcher) []*epochWatcher {
// first, let the watchers proceed if they lag behind
persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
OUTER:
// check for slow merger and await until the merger catch up
for lastPersistedEpoch > *lastMergedEpoch+epochDistance {
select {
case <-s.closeCh:
break OUTER
case ew := <-s.persisterNotifier:
persistWatchers = append(persistWatchers, ew)
*lastMergedEpoch = ew.epoch
}
// let the watchers proceed if they lag behind
persistWatchers = notifyMergeWatchers(lastPersistedEpoch, persistWatchers)
}
return persistWatchers
}
func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error { func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
persisted, err := s.persistSnapshotMaybeMerge(snapshot)
if err != nil {
return err
}
if persisted {
return nil
}
return s.persistSnapshotDirect(snapshot)
}
// DefaultMinSegmentsForInMemoryMerge represents the default number of
// in-memory zap segments that persistSnapshotMaybeMerge() needs to
// see in an IndexSnapshot before it decides to merge and persist
// those segments
var DefaultMinSegmentsForInMemoryMerge = 2
// persistSnapshotMaybeMerge examines the snapshot and might merge and
// persist the in-memory zap segments if there are enough of them
func (s *Scorch) persistSnapshotMaybeMerge(snapshot *IndexSnapshot) (
bool, error) {
// collect the in-memory zap segments (SegmentBase instances)
var sbs []*zap.SegmentBase
var sbsDrops []*roaring.Bitmap
var sbsIndexes []int
for i, segmentSnapshot := range snapshot.segment {
if sb, ok := segmentSnapshot.segment.(*zap.SegmentBase); ok {
sbs = append(sbs, sb)
sbsDrops = append(sbsDrops, segmentSnapshot.deleted)
sbsIndexes = append(sbsIndexes, i)
}
}
if len(sbs) < DefaultMinSegmentsForInMemoryMerge {
return false, nil
}
_, newSnapshot, newSegmentID, err := s.mergeSegmentBases(
snapshot, sbs, sbsDrops, sbsIndexes, DefaultChunkFactor)
if err != nil {
return false, err
}
if newSnapshot == nil {
return false, nil
}
defer func() {
_ = newSnapshot.DecRef()
}()
mergedSegmentIDs := map[uint64]struct{}{}
for _, idx := range sbsIndexes {
mergedSegmentIDs[snapshot.segment[idx].id] = struct{}{}
}
// construct a snapshot that's logically equivalent to the input
// snapshot, but with merged segments replaced by the new segment
equiv := &IndexSnapshot{
parent: snapshot.parent,
segment: make([]*SegmentSnapshot, 0, len(snapshot.segment)),
internal: snapshot.internal,
epoch: snapshot.epoch,
}
// copy to the equiv the segments that weren't replaced
for _, segment := range snapshot.segment {
if _, wasMerged := mergedSegmentIDs[segment.id]; !wasMerged {
equiv.segment = append(equiv.segment, segment)
}
}
// append to the equiv the new segment
for _, segment := range newSnapshot.segment {
if segment.id == newSegmentID {
equiv.segment = append(equiv.segment, &SegmentSnapshot{
id: newSegmentID,
segment: segment.segment,
deleted: nil, // nil since merging handled deletions
})
break
}
}
err = s.persistSnapshotDirect(equiv)
if err != nil {
return false, err
}
return true, nil
}
func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot) (err error) {
// start a write transaction // start a write transaction
tx, err := s.rootBolt.Begin(true) tx, err := s.rootBolt.Begin(true)
if err != nil { if err != nil {
return err return err
} }
// defer fsync of the rootbolt // defer rollback on error
defer func() { defer func() {
if err == nil { if err != nil {
err = s.rootBolt.Sync()
}
}()
// defer commit/rollback transaction
defer func() {
if err == nil {
err = tx.Commit()
} else {
_ = tx.Rollback() _ = tx.Rollback()
} }
}() }()
@ -172,20 +318,20 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
newSegmentPaths := make(map[uint64]string) newSegmentPaths := make(map[uint64]string)
// first ensure that each segment in this snapshot has been persisted // first ensure that each segment in this snapshot has been persisted
for i, segmentSnapshot := range snapshot.segment { for _, segmentSnapshot := range snapshot.segment {
snapshotSegmentKey := segment.EncodeUvarintAscending(nil, uint64(i)) snapshotSegmentKey := segment.EncodeUvarintAscending(nil, segmentSnapshot.id)
snapshotSegmentBucket, err2 := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey) snapshotSegmentBucket, err := snapshotBucket.CreateBucketIfNotExists(snapshotSegmentKey)
if err2 != nil { if err != nil {
return err2 return err
} }
switch seg := segmentSnapshot.segment.(type) { switch seg := segmentSnapshot.segment.(type) {
case *zap.SegmentBase: case *zap.SegmentBase:
// need to persist this to disk // need to persist this to disk
filename := zapFileName(segmentSnapshot.id) filename := zapFileName(segmentSnapshot.id)
path := s.path + string(os.PathSeparator) + filename path := s.path + string(os.PathSeparator) + filename
err2 := zap.PersistSegmentBase(seg, path) err = zap.PersistSegmentBase(seg, path)
if err2 != nil { if err != nil {
return fmt.Errorf("error persisting segment: %v", err2) return fmt.Errorf("error persisting segment: %v", err)
} }
newSegmentPaths[segmentSnapshot.id] = path newSegmentPaths[segmentSnapshot.id] = path
err = snapshotSegmentBucket.Put(boltPathKey, []byte(filename)) err = snapshotSegmentBucket.Put(boltPathKey, []byte(filename))
@ -218,19 +364,28 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
} }
} }
// only alter the root if we actually persisted a segment // we need to swap in a new root only when we've persisted 1 or
// (sometimes its just a new snapshot, possibly with new internal values) // more segments -- whereby the new root would have 1-for-1
// replacements of in-memory segments with file-based segments
//
// other cases like updates to internal values only, and/or when
// there are only deletions, are already covered and persisted by
// the newly populated boltdb snapshotBucket above
if len(newSegmentPaths) > 0 { if len(newSegmentPaths) > 0 {
// now try to open all the new snapshots // now try to open all the new snapshots
newSegments := make(map[uint64]segment.Segment) newSegments := make(map[uint64]segment.Segment)
defer func() {
for _, s := range newSegments {
if s != nil {
// cleanup segments that were opened but not
// swapped into the new root
_ = s.Close()
}
}
}()
for segmentID, path := range newSegmentPaths { for segmentID, path := range newSegmentPaths {
newSegments[segmentID], err = zap.Open(path) newSegments[segmentID], err = zap.Open(path)
if err != nil { if err != nil {
for _, s := range newSegments {
if s != nil {
_ = s.Close() // cleanup segments that were successfully opened
}
}
return fmt.Errorf("error opening new segment at %s, %v", path, err) return fmt.Errorf("error opening new segment at %s, %v", path, err)
} }
} }
@ -255,6 +410,7 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
cachedDocs: segmentSnapshot.cachedDocs, cachedDocs: segmentSnapshot.cachedDocs,
} }
newIndexSnapshot.segment[i] = newSegmentSnapshot newIndexSnapshot.segment[i] = newSegmentSnapshot
delete(newSegments, segmentSnapshot.id)
// update items persisted incase of a new segment snapshot // update items persisted incase of a new segment snapshot
atomic.AddUint64(&s.stats.numItemsPersisted, newSegmentSnapshot.Count()) atomic.AddUint64(&s.stats.numItemsPersisted, newSegmentSnapshot.Count())
} else { } else {
@ -266,9 +422,7 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
for k, v := range s.root.internal { for k, v := range s.root.internal {
newIndexSnapshot.internal[k] = v newIndexSnapshot.internal[k] = v
} }
for _, filename := range filenames {
delete(s.ineligibleForRemoval, filename)
}
rootPrev := s.root rootPrev := s.root
s.root = newIndexSnapshot s.root = newIndexSnapshot
s.rootLock.Unlock() s.rootLock.Unlock()
@ -277,6 +431,24 @@ func (s *Scorch) persistSnapshot(snapshot *IndexSnapshot) error {
} }
} }
err = tx.Commit()
if err != nil {
return err
}
err = s.rootBolt.Sync()
if err != nil {
return err
}
// allow files to become eligible for removal after commit, such
// as file segments from snapshots that came from the merger
s.rootLock.Lock()
for _, filename := range filenames {
delete(s.ineligibleForRemoval, filename)
}
s.rootLock.Unlock()
return nil return nil
} }

View File

@ -61,7 +61,7 @@ type Scorch struct {
merges chan *segmentMerge merges chan *segmentMerge
introducerNotifier chan *epochWatcher introducerNotifier chan *epochWatcher
revertToSnapshots chan *snapshotReversion revertToSnapshots chan *snapshotReversion
persisterNotifier chan notificationChan persisterNotifier chan *epochWatcher
rootBolt *bolt.DB rootBolt *bolt.DB
asyncTasks sync.WaitGroup asyncTasks sync.WaitGroup
@ -114,6 +114,25 @@ func (s *Scorch) fireAsyncError(err error) {
} }
func (s *Scorch) Open() error { func (s *Scorch) Open() error {
err := s.openBolt()
if err != nil {
return err
}
s.asyncTasks.Add(1)
go s.mainLoop()
if !s.readOnly && s.path != "" {
s.asyncTasks.Add(1)
go s.persisterLoop()
s.asyncTasks.Add(1)
go s.mergerLoop()
}
return nil
}
func (s *Scorch) openBolt() error {
var ok bool var ok bool
s.path, ok = s.config["path"].(string) s.path, ok = s.config["path"].(string)
if !ok { if !ok {
@ -136,6 +155,7 @@ func (s *Scorch) Open() error {
} }
} }
} }
rootBoltPath := s.path + string(os.PathSeparator) + "root.bolt" rootBoltPath := s.path + string(os.PathSeparator) + "root.bolt"
var err error var err error
if s.path != "" { if s.path != "" {
@ -156,7 +176,7 @@ func (s *Scorch) Open() error {
s.merges = make(chan *segmentMerge) s.merges = make(chan *segmentMerge)
s.introducerNotifier = make(chan *epochWatcher, 1) s.introducerNotifier = make(chan *epochWatcher, 1)
s.revertToSnapshots = make(chan *snapshotReversion) s.revertToSnapshots = make(chan *snapshotReversion)
s.persisterNotifier = make(chan notificationChan) s.persisterNotifier = make(chan *epochWatcher, 1)
if !s.readOnly && s.path != "" { if !s.readOnly && s.path != "" {
err := s.removeOldZapFiles() // Before persister or merger create any new files. err := s.removeOldZapFiles() // Before persister or merger create any new files.
@ -166,16 +186,6 @@ func (s *Scorch) Open() error {
} }
} }
s.asyncTasks.Add(1)
go s.mainLoop()
if !s.readOnly && s.path != "" {
s.asyncTasks.Add(1)
go s.persisterLoop()
s.asyncTasks.Add(1)
go s.mergerLoop()
}
return nil return nil
} }
@ -310,17 +320,21 @@ func (s *Scorch) prepareSegment(newSegment segment.Segment, ids []string,
introduction.persisted = make(chan error, 1) introduction.persisted = make(chan error, 1)
} }
// get read lock, to optimistically prepare obsoleted info // optimistically prepare obsoletes outside of rootLock
s.rootLock.RLock() s.rootLock.RLock()
for _, seg := range s.root.segment { root := s.root
root.AddRef()
s.rootLock.RUnlock()
for _, seg := range root.segment {
delta, err := seg.segment.DocNumbers(ids) delta, err := seg.segment.DocNumbers(ids)
if err != nil { if err != nil {
s.rootLock.RUnlock()
return err return err
} }
introduction.obsoletes[seg.id] = delta introduction.obsoletes[seg.id] = delta
} }
s.rootLock.RUnlock()
_ = root.DecRef()
s.introductions <- introduction s.introductions <- introduction

View File

@ -95,6 +95,21 @@ func (s *Segment) initializeDict(results []*index.AnalysisResult) {
var numTokenFrequencies int var numTokenFrequencies int
var totLocs int var totLocs int
// initial scan for all fieldID's to sort them
for _, result := range results {
for _, field := range result.Document.CompositeFields {
s.getOrDefineField(field.Name())
}
for _, field := range result.Document.Fields {
s.getOrDefineField(field.Name())
}
}
sort.Strings(s.FieldsInv[1:]) // keep _id as first field
s.FieldsMap = make(map[string]uint16, len(s.FieldsInv))
for fieldID, fieldName := range s.FieldsInv {
s.FieldsMap[fieldName] = uint16(fieldID + 1)
}
processField := func(fieldID uint16, tfs analysis.TokenFrequencies) { processField := func(fieldID uint16, tfs analysis.TokenFrequencies) {
for term, tf := range tfs { for term, tf := range tfs {
pidPlus1, exists := s.Dicts[fieldID][term] pidPlus1, exists := s.Dicts[fieldID][term]

View File

@ -76,6 +76,8 @@ type DictionaryIterator struct {
prefix string prefix string
end string end string
offset int offset int
dictEntry index.DictEntry // reused across Next()'s
} }
// Next returns the next entry in the dictionary // Next returns the next entry in the dictionary
@ -95,8 +97,7 @@ func (d *DictionaryIterator) Next() (*index.DictEntry, error) {
d.offset++ d.offset++
postingID := d.d.segment.Dicts[d.d.fieldID][next] postingID := d.d.segment.Dicts[d.d.fieldID][next]
return &index.DictEntry{ d.dictEntry.Term = next
Term: next, d.dictEntry.Count = d.d.segment.Postings[postingID-1].GetCardinality()
Count: d.d.segment.Postings[postingID-1].GetCardinality(), return &d.dictEntry, nil
}, nil
} }

View File

@ -28,7 +28,7 @@ import (
"github.com/golang/snappy" "github.com/golang/snappy"
) )
const version uint32 = 2 const version uint32 = 3
const fieldNotUninverted = math.MaxUint64 const fieldNotUninverted = math.MaxUint64
@ -187,79 +187,42 @@ func persistBase(memSegment *mem.Segment, cr *CountHashWriter, chunkFactor uint3
} }
func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error) { func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error) {
var curr int var curr int
var metaBuf bytes.Buffer var metaBuf bytes.Buffer
var data, compressed []byte var data, compressed []byte
metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)
docNumOffsets := make(map[int]uint64, len(memSegment.Stored)) docNumOffsets := make(map[int]uint64, len(memSegment.Stored))
for docNum, storedValues := range memSegment.Stored { for docNum, storedValues := range memSegment.Stored {
if docNum != 0 { if docNum != 0 {
// reset buffer if necessary // reset buffer if necessary
curr = 0
metaBuf.Reset() metaBuf.Reset()
data = data[:0] data = data[:0]
compressed = compressed[:0] compressed = compressed[:0]
curr = 0
} }
metaEncoder := govarint.NewU64Base128Encoder(&metaBuf)
st := memSegment.StoredTypes[docNum] st := memSegment.StoredTypes[docNum]
sp := memSegment.StoredPos[docNum] sp := memSegment.StoredPos[docNum]
// encode fields in order // encode fields in order
for fieldID := range memSegment.FieldsInv { for fieldID := range memSegment.FieldsInv {
if storedFieldValues, ok := storedValues[uint16(fieldID)]; ok { if storedFieldValues, ok := storedValues[uint16(fieldID)]; ok {
// has stored values for this field
num := len(storedFieldValues)
stf := st[uint16(fieldID)] stf := st[uint16(fieldID)]
spf := sp[uint16(fieldID)] spf := sp[uint16(fieldID)]
// process each value var err2 error
for i := 0; i < num; i++ { curr, data, err2 = persistStoredFieldValues(fieldID,
// encode field storedFieldValues, stf, spf, curr, metaEncoder, data)
_, err2 := metaEncoder.PutU64(uint64(fieldID))
if err2 != nil {
return 0, err2
}
// encode type
_, err2 = metaEncoder.PutU64(uint64(stf[i]))
if err2 != nil {
return 0, err2
}
// encode start offset
_, err2 = metaEncoder.PutU64(uint64(curr))
if err2 != nil {
return 0, err2
}
// end len
_, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
if err2 != nil {
return 0, err2
}
// encode number of array pos
_, err2 = metaEncoder.PutU64(uint64(len(spf[i])))
if err2 != nil {
return 0, err2
}
// encode all array positions
for _, pos := range spf[i] {
_, err2 = metaEncoder.PutU64(pos)
if err2 != nil { if err2 != nil {
return 0, err2 return 0, err2
} }
} }
// append data
data = append(data, storedFieldValues[i]...)
// update curr
curr += len(storedFieldValues[i])
} }
}
}
metaEncoder.Close()
metaEncoder.Close()
metaBytes := metaBuf.Bytes() metaBytes := metaBuf.Bytes()
// compress the data // compress the data
@ -299,6 +262,51 @@ func persistStored(memSegment *mem.Segment, w *CountHashWriter) (uint64, error)
return rv, nil return rv, nil
} }
func persistStoredFieldValues(fieldID int,
storedFieldValues [][]byte, stf []byte, spf [][]uint64,
curr int, metaEncoder *govarint.Base128Encoder, data []byte) (
int, []byte, error) {
for i := 0; i < len(storedFieldValues); i++ {
// encode field
_, err := metaEncoder.PutU64(uint64(fieldID))
if err != nil {
return 0, nil, err
}
// encode type
_, err = metaEncoder.PutU64(uint64(stf[i]))
if err != nil {
return 0, nil, err
}
// encode start offset
_, err = metaEncoder.PutU64(uint64(curr))
if err != nil {
return 0, nil, err
}
// end len
_, err = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
if err != nil {
return 0, nil, err
}
// encode number of array pos
_, err = metaEncoder.PutU64(uint64(len(spf[i])))
if err != nil {
return 0, nil, err
}
// encode all array positions
for _, pos := range spf[i] {
_, err = metaEncoder.PutU64(pos)
if err != nil {
return 0, nil, err
}
}
data = append(data, storedFieldValues[i]...)
curr += len(storedFieldValues[i])
}
return curr, data, nil
}
func persistPostingDetails(memSegment *mem.Segment, w *CountHashWriter, chunkFactor uint32) ([]uint64, []uint64, error) { func persistPostingDetails(memSegment *mem.Segment, w *CountHashWriter, chunkFactor uint32) ([]uint64, []uint64, error) {
var freqOffsets, locOfffsets []uint64 var freqOffsets, locOfffsets []uint64
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), uint64(len(memSegment.Stored)-1)) tfEncoder := newChunkedIntCoder(uint64(chunkFactor), uint64(len(memSegment.Stored)-1))
@ -580,7 +588,7 @@ func persistDocValues(memSegment *mem.Segment, w *CountHashWriter,
if err != nil { if err != nil {
return nil, err return nil, err
} }
// resetting encoder for the next field // reseting encoder for the next field
fdvEncoder.Reset() fdvEncoder.Reset()
} }
@ -625,12 +633,21 @@ func NewSegmentBase(memSegment *mem.Segment, chunkFactor uint32) (*SegmentBase,
return nil, err return nil, err
} }
return InitSegmentBase(br.Bytes(), cr.Sum32(), chunkFactor,
memSegment.FieldsMap, memSegment.FieldsInv, numDocs,
storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs)
}
func InitSegmentBase(mem []byte, memCRC uint32, chunkFactor uint32,
fieldsMap map[string]uint16, fieldsInv []string, numDocs uint64,
storedIndexOffset uint64, fieldsIndexOffset uint64, docValueOffset uint64,
dictLocs []uint64) (*SegmentBase, error) {
sb := &SegmentBase{ sb := &SegmentBase{
mem: br.Bytes(), mem: mem,
memCRC: cr.Sum32(), memCRC: memCRC,
chunkFactor: chunkFactor, chunkFactor: chunkFactor,
fieldsMap: memSegment.FieldsMap, fieldsMap: fieldsMap,
fieldsInv: memSegment.FieldsInv, fieldsInv: fieldsInv,
numDocs: numDocs, numDocs: numDocs,
storedIndexOffset: storedIndexOffset, storedIndexOffset: storedIndexOffset,
fieldsIndexOffset: fieldsIndexOffset, fieldsIndexOffset: fieldsIndexOffset,
@ -639,7 +656,7 @@ func NewSegmentBase(memSegment *mem.Segment, chunkFactor uint32) (*SegmentBase,
fieldDvIterMap: make(map[uint16]*docValueIterator), fieldDvIterMap: make(map[uint16]*docValueIterator),
} }
err = sb.loadDvIterators() err := sb.loadDvIterators()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -39,7 +39,7 @@ type chunkedContentCoder struct {
// MetaData represents the data information inside a // MetaData represents the data information inside a
// chunk. // chunk.
type MetaData struct { type MetaData struct {
DocID uint64 // docid of the data inside the chunk DocNum uint64 // docNum of the data inside the chunk
DocDvLoc uint64 // starting offset for a given docid DocDvLoc uint64 // starting offset for a given docid
DocDvLen uint64 // length of data inside the chunk for the given docid DocDvLen uint64 // length of data inside the chunk for the given docid
} }
@ -52,7 +52,7 @@ func newChunkedContentCoder(chunkSize uint64,
rv := &chunkedContentCoder{ rv := &chunkedContentCoder{
chunkSize: chunkSize, chunkSize: chunkSize,
chunkLens: make([]uint64, total), chunkLens: make([]uint64, total),
chunkMeta: []MetaData{}, chunkMeta: make([]MetaData, 0, total),
} }
return rv return rv
@ -68,7 +68,7 @@ func (c *chunkedContentCoder) Reset() {
for i := range c.chunkLens { for i := range c.chunkLens {
c.chunkLens[i] = 0 c.chunkLens[i] = 0
} }
c.chunkMeta = []MetaData{} c.chunkMeta = c.chunkMeta[:0]
} }
// Close indicates you are done calling Add() this allows // Close indicates you are done calling Add() this allows
@ -88,7 +88,7 @@ func (c *chunkedContentCoder) flushContents() error {
// write out the metaData slice // write out the metaData slice
for _, meta := range c.chunkMeta { for _, meta := range c.chunkMeta {
_, err := writeUvarints(&c.chunkMetaBuf, meta.DocID, meta.DocDvLoc, meta.DocDvLen) _, err := writeUvarints(&c.chunkMetaBuf, meta.DocNum, meta.DocDvLoc, meta.DocDvLen)
if err != nil { if err != nil {
return err return err
} }
@ -118,7 +118,7 @@ func (c *chunkedContentCoder) Add(docNum uint64, vals []byte) error {
// clearing the chunk specific meta for next chunk // clearing the chunk specific meta for next chunk
c.chunkBuf.Reset() c.chunkBuf.Reset()
c.chunkMetaBuf.Reset() c.chunkMetaBuf.Reset()
c.chunkMeta = []MetaData{} c.chunkMeta = c.chunkMeta[:0]
c.currChunk = chunk c.currChunk = chunk
} }
@ -130,7 +130,7 @@ func (c *chunkedContentCoder) Add(docNum uint64, vals []byte) error {
} }
c.chunkMeta = append(c.chunkMeta, MetaData{ c.chunkMeta = append(c.chunkMeta, MetaData{
DocID: docNum, DocNum: docNum,
DocDvLoc: uint64(dvOffset), DocDvLoc: uint64(dvOffset),
DocDvLen: uint64(dvSize), DocDvLen: uint64(dvSize),
}) })

View File

@ -34,32 +34,47 @@ type Dictionary struct {
// PostingsList returns the postings list for the specified term // PostingsList returns the postings list for the specified term
func (d *Dictionary) PostingsList(term string, except *roaring.Bitmap) (segment.PostingsList, error) { func (d *Dictionary) PostingsList(term string, except *roaring.Bitmap) (segment.PostingsList, error) {
return d.postingsList([]byte(term), except) return d.postingsList([]byte(term), except, nil)
} }
func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap) (*PostingsList, error) { func (d *Dictionary) postingsList(term []byte, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
rv := &PostingsList{ if d.fst == nil {
sb: d.sb, return d.postingsListInit(rv, except), nil
term: term,
except: except,
} }
if d.fst != nil {
postingsOffset, exists, err := d.fst.Get(term) postingsOffset, exists, err := d.fst.Get(term)
if err != nil { if err != nil {
return nil, fmt.Errorf("vellum err: %v", err) return nil, fmt.Errorf("vellum err: %v", err)
} }
if exists { if !exists {
err = rv.read(postingsOffset, d) return d.postingsListInit(rv, except), nil
}
return d.postingsListFromOffset(postingsOffset, except, rv)
}
func (d *Dictionary) postingsListFromOffset(postingsOffset uint64, except *roaring.Bitmap, rv *PostingsList) (*PostingsList, error) {
rv = d.postingsListInit(rv, except)
err := rv.read(postingsOffset, d)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
}
return rv, nil return rv, nil
} }
func (d *Dictionary) postingsListInit(rv *PostingsList, except *roaring.Bitmap) *PostingsList {
if rv == nil {
rv = &PostingsList{}
} else {
*rv = PostingsList{} // clear the struct
}
rv.sb = d.sb
rv.except = except
return rv
}
// Iterator returns an iterator for this dictionary // Iterator returns an iterator for this dictionary
func (d *Dictionary) Iterator() segment.DictionaryIterator { func (d *Dictionary) Iterator() segment.DictionaryIterator {
rv := &DictionaryIterator{ rv := &DictionaryIterator{

View File

@ -99,7 +99,7 @@ func (s *SegmentBase) loadFieldDocValueIterator(field string,
func (di *docValueIterator) loadDvChunk(chunkNumber, func (di *docValueIterator) loadDvChunk(chunkNumber,
localDocNum uint64, s *SegmentBase) error { localDocNum uint64, s *SegmentBase) error {
// advance to the chunk where the docValues // advance to the chunk where the docValues
// reside for the given docID // reside for the given docNum
destChunkDataLoc := di.dvDataLoc destChunkDataLoc := di.dvDataLoc
for i := 0; i < int(chunkNumber); i++ { for i := 0; i < int(chunkNumber); i++ {
destChunkDataLoc += di.chunkLens[i] destChunkDataLoc += di.chunkLens[i]
@ -116,7 +116,7 @@ func (di *docValueIterator) loadDvChunk(chunkNumber,
offset := uint64(0) offset := uint64(0)
di.curChunkHeader = make([]MetaData, int(numDocs)) di.curChunkHeader = make([]MetaData, int(numDocs))
for i := 0; i < int(numDocs); i++ { for i := 0; i < int(numDocs); i++ {
di.curChunkHeader[i].DocID, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64]) di.curChunkHeader[i].DocNum, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
offset += uint64(read) offset += uint64(read)
di.curChunkHeader[i].DocDvLoc, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64]) di.curChunkHeader[i].DocDvLoc, read = binary.Uvarint(s.mem[chunkMetaLoc+offset : chunkMetaLoc+offset+binary.MaxVarintLen64])
offset += uint64(read) offset += uint64(read)
@ -131,10 +131,10 @@ func (di *docValueIterator) loadDvChunk(chunkNumber,
return nil return nil
} }
func (di *docValueIterator) visitDocValues(docID uint64, func (di *docValueIterator) visitDocValues(docNum uint64,
visitor index.DocumentFieldTermVisitor) error { visitor index.DocumentFieldTermVisitor) error {
// binary search the term locations for the docID // binary search the term locations for the docNum
start, length := di.getDocValueLocs(docID) start, length := di.getDocValueLocs(docNum)
if start == math.MaxUint64 || length == math.MaxUint64 { if start == math.MaxUint64 || length == math.MaxUint64 {
return nil return nil
} }
@ -144,7 +144,7 @@ func (di *docValueIterator) visitDocValues(docID uint64,
return err return err
} }
// pick the terms for the given docID // pick the terms for the given docNum
uncompressed = uncompressed[start : start+length] uncompressed = uncompressed[start : start+length]
for { for {
i := bytes.Index(uncompressed, termSeparatorSplitSlice) i := bytes.Index(uncompressed, termSeparatorSplitSlice)
@ -159,11 +159,11 @@ func (di *docValueIterator) visitDocValues(docID uint64,
return nil return nil
} }
func (di *docValueIterator) getDocValueLocs(docID uint64) (uint64, uint64) { func (di *docValueIterator) getDocValueLocs(docNum uint64) (uint64, uint64) {
i := sort.Search(len(di.curChunkHeader), func(i int) bool { i := sort.Search(len(di.curChunkHeader), func(i int) bool {
return di.curChunkHeader[i].DocID >= docID return di.curChunkHeader[i].DocNum >= docNum
}) })
if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocID == docID { if i < len(di.curChunkHeader) && di.curChunkHeader[i].DocNum == docNum {
return di.curChunkHeader[i].DocDvLoc, di.curChunkHeader[i].DocDvLen return di.curChunkHeader[i].DocDvLoc, di.curChunkHeader[i].DocDvLen
} }
return math.MaxUint64, math.MaxUint64 return math.MaxUint64, math.MaxUint64

View File

@ -0,0 +1,124 @@
// Copyright (c) 2018 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package zap
import (
"bytes"
"github.com/couchbase/vellum"
)
// enumerator provides an ordered traversal of multiple vellum
// iterators. Like JOIN of iterators, the enumerator produces a
// sequence of (key, iteratorIndex, value) tuples, sorted by key ASC,
// then iteratorIndex ASC, where the same key might be seen or
// repeated across multiple child iterators.
type enumerator struct {
itrs []vellum.Iterator
currKs [][]byte
currVs []uint64
lowK []byte
lowIdxs []int
lowCurr int
}
// newEnumerator returns a new enumerator over the vellum Iterators
func newEnumerator(itrs []vellum.Iterator) (*enumerator, error) {
rv := &enumerator{
itrs: itrs,
currKs: make([][]byte, len(itrs)),
currVs: make([]uint64, len(itrs)),
lowIdxs: make([]int, 0, len(itrs)),
}
for i, itr := range rv.itrs {
rv.currKs[i], rv.currVs[i] = itr.Current()
}
rv.updateMatches()
if rv.lowK == nil {
return rv, vellum.ErrIteratorDone
}
return rv, nil
}
// updateMatches maintains the low key matches based on the currKs
func (m *enumerator) updateMatches() {
m.lowK = nil
m.lowIdxs = m.lowIdxs[:0]
m.lowCurr = 0
for i, key := range m.currKs {
if key == nil {
continue
}
cmp := bytes.Compare(key, m.lowK)
if cmp < 0 || m.lowK == nil {
// reached a new low
m.lowK = key
m.lowIdxs = m.lowIdxs[:0]
m.lowIdxs = append(m.lowIdxs, i)
} else if cmp == 0 {
m.lowIdxs = append(m.lowIdxs, i)
}
}
}
// Current returns the enumerator's current key, iterator-index, and
// value. If the enumerator is not pointing at a valid value (because
// Next returned an error previously), Current will return nil,0,0.
func (m *enumerator) Current() ([]byte, int, uint64) {
var i int
var v uint64
if m.lowCurr < len(m.lowIdxs) {
i = m.lowIdxs[m.lowCurr]
v = m.currVs[i]
}
return m.lowK, i, v
}
// Next advances the enumerator to the next key/iterator/value result,
// else vellum.ErrIteratorDone is returned.
func (m *enumerator) Next() error {
m.lowCurr += 1
if m.lowCurr >= len(m.lowIdxs) {
// move all the current low iterators forwards
for _, vi := range m.lowIdxs {
err := m.itrs[vi].Next()
if err != nil && err != vellum.ErrIteratorDone {
return err
}
m.currKs[vi], m.currVs[vi] = m.itrs[vi].Current()
}
m.updateMatches()
}
if m.lowK == nil {
return vellum.ErrIteratorDone
}
return nil
}
// Close all the underlying Iterators. The first error, if any, will
// be returned.
func (m *enumerator) Close() error {
var rv error
for _, itr := range m.itrs {
err := itr.Close()
if rv == nil {
rv = err
}
}
return rv
}

View File

@ -30,6 +30,8 @@ type chunkedIntCoder struct {
encoder *govarint.Base128Encoder encoder *govarint.Base128Encoder
chunkLens []uint64 chunkLens []uint64
currChunk uint64 currChunk uint64
buf []byte
} }
// newChunkedIntCoder returns a new chunk int coder which packs data into // newChunkedIntCoder returns a new chunk int coder which packs data into
@ -67,12 +69,8 @@ func (c *chunkedIntCoder) Add(docNum uint64, vals ...uint64) error {
// starting a new chunk // starting a new chunk
if c.encoder != nil { if c.encoder != nil {
// close out last // close out last
c.encoder.Close() c.Close()
encodingBytes := c.chunkBuf.Bytes()
c.chunkLens[c.currChunk] = uint64(len(encodingBytes))
c.final = append(c.final, encodingBytes...)
c.chunkBuf.Reset() c.chunkBuf.Reset()
c.encoder = govarint.NewU64Base128Encoder(&c.chunkBuf)
} }
c.currChunk = chunk c.currChunk = chunk
} }
@ -98,26 +96,25 @@ func (c *chunkedIntCoder) Close() {
// Write commits all the encoded chunked integers to the provided writer. // Write commits all the encoded chunked integers to the provided writer.
func (c *chunkedIntCoder) Write(w io.Writer) (int, error) { func (c *chunkedIntCoder) Write(w io.Writer) (int, error) {
var tw int bufNeeded := binary.MaxVarintLen64 * (1 + len(c.chunkLens))
buf := make([]byte, binary.MaxVarintLen64) if len(c.buf) < bufNeeded {
// write out the number of chunks c.buf = make([]byte, bufNeeded)
}
buf := c.buf
// write out the number of chunks & each chunkLen
n := binary.PutUvarint(buf, uint64(len(c.chunkLens))) n := binary.PutUvarint(buf, uint64(len(c.chunkLens)))
nw, err := w.Write(buf[:n])
tw += nw
if err != nil {
return tw, err
}
// write out the chunk lens
for _, chunkLen := range c.chunkLens { for _, chunkLen := range c.chunkLens {
n := binary.PutUvarint(buf, uint64(chunkLen)) n += binary.PutUvarint(buf[n:], uint64(chunkLen))
nw, err = w.Write(buf[:n]) }
tw += nw
tw, err := w.Write(buf[:n])
if err != nil { if err != nil {
return tw, err return tw, err
} }
}
// write out the data // write out the data
nw, err = w.Write(c.final) nw, err := w.Write(c.final)
tw += nw tw += nw
if err != nil { if err != nil {
return tw, err return tw, err

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"math" "math"
"os" "os"
"sort"
"github.com/RoaringBitmap/roaring" "github.com/RoaringBitmap/roaring"
"github.com/Smerity/govarint" "github.com/Smerity/govarint"
@ -28,6 +29,8 @@ import (
"github.com/golang/snappy" "github.com/golang/snappy"
) )
const docDropped = math.MaxUint64 // sentinel docNum to represent a deleted doc
// Merge takes a slice of zap segments and bit masks describing which // Merge takes a slice of zap segments and bit masks describing which
// documents may be dropped, and creates a new segment containing the // documents may be dropped, and creates a new segment containing the
// remaining data. This new segment is built at the specified path, // remaining data. This new segment is built at the specified path,
@ -46,47 +49,26 @@ func Merge(segments []*Segment, drops []*roaring.Bitmap, path string,
_ = os.Remove(path) _ = os.Remove(path)
} }
segmentBases := make([]*SegmentBase, len(segments))
for segmenti, segment := range segments {
segmentBases[segmenti] = &segment.SegmentBase
}
// buffer the output // buffer the output
br := bufio.NewWriter(f) br := bufio.NewWriter(f)
// wrap it for counting (tracking offsets) // wrap it for counting (tracking offsets)
cr := NewCountHashWriter(br) cr := NewCountHashWriter(br)
fieldsInv := mergeFields(segments) newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, _, _, _, err :=
fieldsMap := mapFields(fieldsInv) MergeToWriter(segmentBases, drops, chunkFactor, cr)
var newDocNums [][]uint64
var storedIndexOffset uint64
fieldDvLocsOffset := uint64(fieldNotUninverted)
var dictLocs []uint64
newSegDocCount := computeNewDocCount(segments, drops)
if newSegDocCount > 0 {
storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
fieldsMap, fieldsInv, newSegDocCount, cr)
if err != nil { if err != nil {
cleanup() cleanup()
return nil, err return nil, err
} }
dictLocs, fieldDvLocsOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap, err = persistFooter(numDocs, storedIndexOffset, fieldsIndexOffset,
newDocNums, newSegDocCount, chunkFactor, cr) docValueOffset, chunkFactor, cr.Sum32(), cr)
if err != nil {
cleanup()
return nil, err
}
} else {
dictLocs = make([]uint64, len(fieldsInv))
}
fieldsIndexOffset, err := persistFields(fieldsInv, cr, dictLocs)
if err != nil {
cleanup()
return nil, err
}
err = persistFooter(newSegDocCount, storedIndexOffset,
fieldsIndexOffset, fieldDvLocsOffset, chunkFactor, cr.Sum32(), cr)
if err != nil { if err != nil {
cleanup() cleanup()
return nil, err return nil, err
@ -113,21 +95,59 @@ func Merge(segments []*Segment, drops []*roaring.Bitmap, path string,
return newDocNums, nil return newDocNums, nil
} }
// mapFields takes the fieldsInv list and builds the map func MergeToWriter(segments []*SegmentBase, drops []*roaring.Bitmap,
chunkFactor uint32, cr *CountHashWriter) (
newDocNums [][]uint64,
numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset uint64,
dictLocs []uint64, fieldsInv []string, fieldsMap map[string]uint16,
err error) {
docValueOffset = uint64(fieldNotUninverted)
var fieldsSame bool
fieldsSame, fieldsInv = mergeFields(segments)
fieldsMap = mapFields(fieldsInv)
numDocs = computeNewDocCount(segments, drops)
if numDocs > 0 {
storedIndexOffset, newDocNums, err = mergeStoredAndRemap(segments, drops,
fieldsMap, fieldsInv, fieldsSame, numDocs, cr)
if err != nil {
return nil, 0, 0, 0, 0, nil, nil, nil, err
}
dictLocs, docValueOffset, err = persistMergedRest(segments, drops, fieldsInv, fieldsMap,
newDocNums, numDocs, chunkFactor, cr)
if err != nil {
return nil, 0, 0, 0, 0, nil, nil, nil, err
}
} else {
dictLocs = make([]uint64, len(fieldsInv))
}
fieldsIndexOffset, err = persistFields(fieldsInv, cr, dictLocs)
if err != nil {
return nil, 0, 0, 0, 0, nil, nil, nil, err
}
return newDocNums, numDocs, storedIndexOffset, fieldsIndexOffset, docValueOffset, dictLocs, fieldsInv, fieldsMap, nil
}
// mapFields takes the fieldsInv list and returns a map of fieldName
// to fieldID+1
func mapFields(fields []string) map[string]uint16 { func mapFields(fields []string) map[string]uint16 {
rv := make(map[string]uint16, len(fields)) rv := make(map[string]uint16, len(fields))
for i, fieldName := range fields { for i, fieldName := range fields {
rv[fieldName] = uint16(i) rv[fieldName] = uint16(i) + 1
} }
return rv return rv
} }
// computeNewDocCount determines how many documents will be in the newly // computeNewDocCount determines how many documents will be in the newly
// merged segment when obsoleted docs are dropped // merged segment when obsoleted docs are dropped
func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 { func computeNewDocCount(segments []*SegmentBase, drops []*roaring.Bitmap) uint64 {
var newDocCount uint64 var newDocCount uint64
for segI, segment := range segments { for segI, segment := range segments {
newDocCount += segment.NumDocs() newDocCount += segment.numDocs
if drops[segI] != nil { if drops[segI] != nil {
newDocCount -= drops[segI].GetCardinality() newDocCount -= drops[segI].GetCardinality()
} }
@ -135,8 +155,8 @@ func computeNewDocCount(segments []*Segment, drops []*roaring.Bitmap) uint64 {
return newDocCount return newDocCount
} }
func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap, func persistMergedRest(segments []*SegmentBase, dropsIn []*roaring.Bitmap,
fieldsInv []string, fieldsMap map[string]uint16, newDocNums [][]uint64, fieldsInv []string, fieldsMap map[string]uint16, newDocNumsIn [][]uint64,
newSegDocCount uint64, chunkFactor uint32, newSegDocCount uint64, chunkFactor uint32,
w *CountHashWriter) ([]uint64, uint64, error) { w *CountHashWriter) ([]uint64, uint64, error) {
@ -144,9 +164,14 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
var bufMaxVarintLen64 []byte = make([]byte, binary.MaxVarintLen64) var bufMaxVarintLen64 []byte = make([]byte, binary.MaxVarintLen64)
var bufLoc []uint64 var bufLoc []uint64
var postings *PostingsList
var postItr *PostingsIterator
rv := make([]uint64, len(fieldsInv)) rv := make([]uint64, len(fieldsInv))
fieldDvLocs := make([]uint64, len(fieldsInv)) fieldDvLocs := make([]uint64, len(fieldsInv))
fieldDvLocsOffset := uint64(fieldNotUninverted)
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
// docTermMap is keyed by docNum, where the array impl provides // docTermMap is keyed by docNum, where the array impl provides
// better memory usage behavior than a sparse-friendlier hashmap // better memory usage behavior than a sparse-friendlier hashmap
@ -166,36 +191,31 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
return nil, 0, err return nil, 0, err
} }
// collect FST iterators from all segments for this field // collect FST iterators from all active segments for this field
var newDocNums [][]uint64
var drops []*roaring.Bitmap
var dicts []*Dictionary var dicts []*Dictionary
var itrs []vellum.Iterator var itrs []vellum.Iterator
for _, segment := range segments {
for segmentI, segment := range segments {
dict, err2 := segment.dictionary(fieldName) dict, err2 := segment.dictionary(fieldName)
if err2 != nil { if err2 != nil {
return nil, 0, err2 return nil, 0, err2
} }
dicts = append(dicts, dict)
if dict != nil && dict.fst != nil { if dict != nil && dict.fst != nil {
itr, err2 := dict.fst.Iterator(nil, nil) itr, err2 := dict.fst.Iterator(nil, nil)
if err2 != nil && err2 != vellum.ErrIteratorDone { if err2 != nil && err2 != vellum.ErrIteratorDone {
return nil, 0, err2 return nil, 0, err2
} }
if itr != nil { if itr != nil {
newDocNums = append(newDocNums, newDocNumsIn[segmentI])
drops = append(drops, dropsIn[segmentI])
dicts = append(dicts, dict)
itrs = append(itrs, itr) itrs = append(itrs, itr)
} }
} }
} }
// create merging iterator
mergeItr, err := vellum.NewMergeIterator(itrs, func(postingOffsets []uint64) uint64 {
// we don't actually use the merged value
return 0
})
tfEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
locEncoder := newChunkedIntCoder(uint64(chunkFactor), newSegDocCount-1)
if uint64(cap(docTermMap)) < newSegDocCount { if uint64(cap(docTermMap)) < newSegDocCount {
docTermMap = make([][]byte, newSegDocCount) docTermMap = make([][]byte, newSegDocCount)
} else { } else {
@ -205,30 +225,103 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
} }
} }
for err == nil { var prevTerm []byte
term, _ := mergeItr.Current()
newRoaring := roaring.NewBitmap() newRoaring := roaring.NewBitmap()
newRoaringLocs := roaring.NewBitmap() newRoaringLocs := roaring.NewBitmap()
finishTerm := func(term []byte) error {
if term == nil {
return nil
}
tfEncoder.Close()
locEncoder.Close()
if newRoaring.GetCardinality() > 0 {
// this field/term actually has hits in the new segment, lets write it down
freqOffset := uint64(w.Count())
_, err := tfEncoder.Write(w)
if err != nil {
return err
}
locOffset := uint64(w.Count())
_, err = locEncoder.Write(w)
if err != nil {
return err
}
postingLocOffset := uint64(w.Count())
_, err = writeRoaringWithLen(newRoaringLocs, w, &bufReuse, bufMaxVarintLen64)
if err != nil {
return err
}
postingOffset := uint64(w.Count())
// write out the start of the term info
n := binary.PutUvarint(bufMaxVarintLen64, freqOffset)
_, err = w.Write(bufMaxVarintLen64[:n])
if err != nil {
return err
}
// write out the start of the loc info
n = binary.PutUvarint(bufMaxVarintLen64, locOffset)
_, err = w.Write(bufMaxVarintLen64[:n])
if err != nil {
return err
}
// write out the start of the posting locs
n = binary.PutUvarint(bufMaxVarintLen64, postingLocOffset)
_, err = w.Write(bufMaxVarintLen64[:n])
if err != nil {
return err
}
_, err = writeRoaringWithLen(newRoaring, w, &bufReuse, bufMaxVarintLen64)
if err != nil {
return err
}
err = newVellum.Insert(term, postingOffset)
if err != nil {
return err
}
}
newRoaring = roaring.NewBitmap()
newRoaringLocs = roaring.NewBitmap()
tfEncoder.Reset() tfEncoder.Reset()
locEncoder.Reset() locEncoder.Reset()
// now go back and get posting list for this term return nil
// but pass in the deleted docs for that segment
for dictI, dict := range dicts {
if dict == nil {
continue
} }
postings, err2 := dict.postingsList(term, drops[dictI])
enumerator, err := newEnumerator(itrs)
for err == nil {
term, itrI, postingsOffset := enumerator.Current()
if !bytes.Equal(prevTerm, term) {
// if the term changed, write out the info collected
// for the previous term
err2 := finishTerm(prevTerm)
if err2 != nil {
return nil, 0, err2
}
}
var err2 error
postings, err2 = dicts[itrI].postingsListFromOffset(
postingsOffset, drops[itrI], postings)
if err2 != nil { if err2 != nil {
return nil, 0, err2 return nil, 0, err2
} }
postItr := postings.Iterator() newDocNumsI := newDocNums[itrI]
postItr = postings.iterator(postItr)
next, err2 := postItr.Next() next, err2 := postItr.Next()
for next != nil && err2 == nil { for next != nil && err2 == nil {
hitNewDocNum := newDocNums[dictI][next.Number()] hitNewDocNum := newDocNumsI[next.Number()]
if hitNewDocNum == docDropped { if hitNewDocNum == docDropped {
return nil, 0, fmt.Errorf("see hit with dropped doc num") return nil, 0, fmt.Errorf("see hit with dropped doc num")
} }
@ -248,7 +341,7 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions())) bufLoc = make([]uint64, 0, 5+len(loc.ArrayPositions()))
} }
args := bufLoc[0:5] args := bufLoc[0:5]
args[0] = uint64(fieldsMap[loc.Field()]) args[0] = uint64(fieldsMap[loc.Field()] - 1)
args[1] = loc.Pos() args[1] = loc.Pos()
args[2] = loc.Start() args[2] = loc.Start()
args[3] = loc.End() args[3] = loc.End()
@ -269,67 +362,21 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
if err2 != nil { if err2 != nil {
return nil, 0, err2 return nil, 0, err2
} }
}
tfEncoder.Close() prevTerm = prevTerm[:0] // copy to prevTerm in case Next() reuses term mem
locEncoder.Close() prevTerm = append(prevTerm, term...)
if newRoaring.GetCardinality() > 0 { err = enumerator.Next()
// this field/term actually has hits in the new segment, lets write it down
freqOffset := uint64(w.Count())
_, err = tfEncoder.Write(w)
if err != nil {
return nil, 0, err
}
locOffset := uint64(w.Count())
_, err = locEncoder.Write(w)
if err != nil {
return nil, 0, err
}
postingLocOffset := uint64(w.Count())
_, err = writeRoaringWithLen(newRoaringLocs, w, &bufReuse, bufMaxVarintLen64)
if err != nil {
return nil, 0, err
}
postingOffset := uint64(w.Count())
// write out the start of the term info
buf := bufMaxVarintLen64
n := binary.PutUvarint(buf, freqOffset)
_, err = w.Write(buf[:n])
if err != nil {
return nil, 0, err
}
// write out the start of the loc info
n = binary.PutUvarint(buf, locOffset)
_, err = w.Write(buf[:n])
if err != nil {
return nil, 0, err
}
// write out the start of the loc posting list
n = binary.PutUvarint(buf, postingLocOffset)
_, err = w.Write(buf[:n])
if err != nil {
return nil, 0, err
}
_, err = writeRoaringWithLen(newRoaring, w, &bufReuse, bufMaxVarintLen64)
if err != nil {
return nil, 0, err
}
err = newVellum.Insert(term, postingOffset)
if err != nil {
return nil, 0, err
}
}
err = mergeItr.Next()
} }
if err != nil && err != vellum.ErrIteratorDone { if err != nil && err != vellum.ErrIteratorDone {
return nil, 0, err return nil, 0, err
} }
err = finishTerm(prevTerm)
if err != nil {
return nil, 0, err
}
dictOffset := uint64(w.Count()) dictOffset := uint64(w.Count())
err = newVellum.Close() err = newVellum.Close()
@ -378,7 +425,7 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
} }
} }
fieldDvLocsOffset = uint64(w.Count()) fieldDvLocsOffset := uint64(w.Count())
buf := bufMaxVarintLen64 buf := bufMaxVarintLen64
for _, offset := range fieldDvLocs { for _, offset := range fieldDvLocs {
@ -392,10 +439,8 @@ func persistMergedRest(segments []*Segment, drops []*roaring.Bitmap,
return rv, fieldDvLocsOffset, nil return rv, fieldDvLocsOffset, nil
} }
const docDropped = math.MaxUint64 func mergeStoredAndRemap(segments []*SegmentBase, drops []*roaring.Bitmap,
fieldsMap map[string]uint16, fieldsInv []string, fieldsSame bool, newSegDocCount uint64,
func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
fieldsMap map[string]uint16, fieldsInv []string, newSegDocCount uint64,
w *CountHashWriter) (uint64, [][]uint64, error) { w *CountHashWriter) (uint64, [][]uint64, error) {
var rv [][]uint64 // The remapped or newDocNums for each segment. var rv [][]uint64 // The remapped or newDocNums for each segment.
@ -417,10 +462,30 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
for segI, segment := range segments { for segI, segment := range segments {
segNewDocNums := make([]uint64, segment.numDocs) segNewDocNums := make([]uint64, segment.numDocs)
dropsI := drops[segI]
// optimize when the field mapping is the same across all
// segments and there are no deletions, via byte-copying
// of stored docs bytes directly to the writer
if fieldsSame && (dropsI == nil || dropsI.GetCardinality() == 0) {
err := segment.copyStoredDocs(newDocNum, docNumOffsets, w)
if err != nil {
return 0, nil, err
}
for i := uint64(0); i < segment.numDocs; i++ {
segNewDocNums[i] = newDocNum
newDocNum++
}
rv = append(rv, segNewDocNums)
continue
}
// for each doc num // for each doc num
for docNum := uint64(0); docNum < segment.numDocs; docNum++ { for docNum := uint64(0); docNum < segment.numDocs; docNum++ {
// TODO: roaring's API limits docNums to 32-bits? // TODO: roaring's API limits docNums to 32-bits?
if drops[segI] != nil && drops[segI].Contains(uint32(docNum)) { if dropsI != nil && dropsI.Contains(uint32(docNum)) {
segNewDocNums[docNum] = docDropped segNewDocNums[docNum] = docDropped
continue continue
} }
@ -439,7 +504,7 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
poss[i] = poss[i][:0] poss[i] = poss[i][:0]
} }
err := segment.VisitDocument(docNum, func(field string, typ byte, value []byte, pos []uint64) bool { err := segment.VisitDocument(docNum, func(field string, typ byte, value []byte, pos []uint64) bool {
fieldID := int(fieldsMap[field]) fieldID := int(fieldsMap[field]) - 1
vals[fieldID] = append(vals[fieldID], value) vals[fieldID] = append(vals[fieldID], value)
typs[fieldID] = append(typs[fieldID], typ) typs[fieldID] = append(typs[fieldID], typ)
poss[fieldID] = append(poss[fieldID], pos) poss[fieldID] = append(poss[fieldID], pos)
@ -453,48 +518,15 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
for fieldID := range fieldsInv { for fieldID := range fieldsInv {
storedFieldValues := vals[int(fieldID)] storedFieldValues := vals[int(fieldID)]
// has stored values for this field stf := typs[int(fieldID)]
num := len(storedFieldValues) spf := poss[int(fieldID)]
// process each value var err2 error
for i := 0; i < num; i++ { curr, data, err2 = persistStoredFieldValues(fieldID,
// encode field storedFieldValues, stf, spf, curr, metaEncoder, data)
_, err2 := metaEncoder.PutU64(uint64(fieldID))
if err2 != nil { if err2 != nil {
return 0, nil, err2 return 0, nil, err2
} }
// encode type
_, err2 = metaEncoder.PutU64(uint64(typs[int(fieldID)][i]))
if err2 != nil {
return 0, nil, err2
}
// encode start offset
_, err2 = metaEncoder.PutU64(uint64(curr))
if err2 != nil {
return 0, nil, err2
}
// end len
_, err2 = metaEncoder.PutU64(uint64(len(storedFieldValues[i])))
if err2 != nil {
return 0, nil, err2
}
// encode number of array pos
_, err2 = metaEncoder.PutU64(uint64(len(poss[int(fieldID)][i])))
if err2 != nil {
return 0, nil, err2
}
// encode all array positions
for j := 0; j < len(poss[int(fieldID)][i]); j++ {
_, err2 = metaEncoder.PutU64(poss[int(fieldID)][i][j])
if err2 != nil {
return 0, nil, err2
}
}
// append data
data = append(data, storedFieldValues[i]...)
// update curr
curr += len(storedFieldValues[i])
}
} }
metaEncoder.Close() metaEncoder.Close()
@ -528,36 +560,87 @@ func mergeStoredAndRemap(segments []*Segment, drops []*roaring.Bitmap,
} }
// return value is the start of the stored index // return value is the start of the stored index
offset := uint64(w.Count()) storedIndexOffset := uint64(w.Count())
// now write out the stored doc index // now write out the stored doc index
for docNum := range docNumOffsets { for _, docNumOffset := range docNumOffsets {
err := binary.Write(w, binary.BigEndian, docNumOffsets[docNum]) err := binary.Write(w, binary.BigEndian, docNumOffset)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, err
} }
} }
return offset, rv, nil return storedIndexOffset, rv, nil
} }
// mergeFields builds a unified list of fields used across all the input segments // copyStoredDocs writes out a segment's stored doc info, optimized by
func mergeFields(segments []*Segment) []string { // using a single Write() call for the entire set of bytes. The
fieldsMap := map[string]struct{}{} // newDocNumOffsets is filled with the new offsets for each doc.
func (s *SegmentBase) copyStoredDocs(newDocNum uint64, newDocNumOffsets []uint64,
w *CountHashWriter) error {
if s.numDocs <= 0 {
return nil
}
indexOffset0, storedOffset0, _, _, _ :=
s.getDocStoredOffsets(0) // the segment's first doc
indexOffsetN, storedOffsetN, readN, metaLenN, dataLenN :=
s.getDocStoredOffsets(s.numDocs - 1) // the segment's last doc
storedOffset0New := uint64(w.Count())
storedBytes := s.mem[storedOffset0 : storedOffsetN+readN+metaLenN+dataLenN]
_, err := w.Write(storedBytes)
if err != nil {
return err
}
// remap the storedOffset's for the docs into new offsets relative
// to storedOffset0New, filling the given docNumOffsetsOut array
for indexOffset := indexOffset0; indexOffset <= indexOffsetN; indexOffset += 8 {
storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8])
storedOffsetNew := storedOffset - storedOffset0 + storedOffset0New
newDocNumOffsets[newDocNum] = storedOffsetNew
newDocNum += 1
}
return nil
}
// mergeFields builds a unified list of fields used across all the
// input segments, and computes whether the fields are the same across
// segments (which depends on fields to be sorted in the same way
// across segments)
func mergeFields(segments []*SegmentBase) (bool, []string) {
fieldsSame := true
var segment0Fields []string
if len(segments) > 0 {
segment0Fields = segments[0].Fields()
}
fieldsExist := map[string]struct{}{}
for _, segment := range segments { for _, segment := range segments {
fields := segment.Fields() fields := segment.Fields()
for _, field := range fields { for fieldi, field := range fields {
fieldsMap[field] = struct{}{} fieldsExist[field] = struct{}{}
if len(segment0Fields) != len(fields) || segment0Fields[fieldi] != field {
fieldsSame = false
}
} }
} }
rv := make([]string, 0, len(fieldsMap)) rv := make([]string, 0, len(fieldsExist))
// ensure _id stays first // ensure _id stays first
rv = append(rv, "_id") rv = append(rv, "_id")
for k := range fieldsMap { for k := range fieldsExist {
if k != "_id" { if k != "_id" {
rv = append(rv, k) rv = append(rv, k)
} }
} }
return rv
sort.Strings(rv[1:]) // leave _id as first
return fieldsSame, rv
} }

View File

@ -28,21 +28,27 @@ import (
// PostingsList is an in-memory represenation of a postings list // PostingsList is an in-memory represenation of a postings list
type PostingsList struct { type PostingsList struct {
sb *SegmentBase sb *SegmentBase
term []byte
postingsOffset uint64 postingsOffset uint64
freqOffset uint64 freqOffset uint64
locOffset uint64 locOffset uint64
locBitmap *roaring.Bitmap locBitmap *roaring.Bitmap
postings *roaring.Bitmap postings *roaring.Bitmap
except *roaring.Bitmap except *roaring.Bitmap
postingKey []byte
} }
// Iterator returns an iterator for this postings list // Iterator returns an iterator for this postings list
func (p *PostingsList) Iterator() segment.PostingsIterator { func (p *PostingsList) Iterator() segment.PostingsIterator {
rv := &PostingsIterator{ return p.iterator(nil)
postings: p,
} }
func (p *PostingsList) iterator(rv *PostingsIterator) *PostingsIterator {
if rv == nil {
rv = &PostingsIterator{}
} else {
*rv = PostingsIterator{} // clear the struct
}
rv.postings = p
if p.postings != nil { if p.postings != nil {
// prepare the freq chunk details // prepare the freq chunk details
var n uint64 var n uint64

View File

@ -17,15 +17,27 @@ package zap
import "encoding/binary" import "encoding/binary"
func (s *SegmentBase) getDocStoredMetaAndCompressed(docNum uint64) ([]byte, []byte) { func (s *SegmentBase) getDocStoredMetaAndCompressed(docNum uint64) ([]byte, []byte) {
docStoredStartAddr := s.storedIndexOffset + (8 * docNum) _, storedOffset, n, metaLen, dataLen := s.getDocStoredOffsets(docNum)
docStoredStart := binary.BigEndian.Uint64(s.mem[docStoredStartAddr : docStoredStartAddr+8])
var n uint64 meta := s.mem[storedOffset+n : storedOffset+n+metaLen]
metaLen, read := binary.Uvarint(s.mem[docStoredStart : docStoredStart+binary.MaxVarintLen64]) data := s.mem[storedOffset+n+metaLen : storedOffset+n+metaLen+dataLen]
n += uint64(read)
var dataLen uint64
dataLen, read = binary.Uvarint(s.mem[docStoredStart+n : docStoredStart+n+binary.MaxVarintLen64])
n += uint64(read)
meta := s.mem[docStoredStart+n : docStoredStart+n+metaLen]
data := s.mem[docStoredStart+n+metaLen : docStoredStart+n+metaLen+dataLen]
return meta, data return meta, data
} }
func (s *SegmentBase) getDocStoredOffsets(docNum uint64) (
uint64, uint64, uint64, uint64, uint64) {
indexOffset := s.storedIndexOffset + (8 * docNum)
storedOffset := binary.BigEndian.Uint64(s.mem[indexOffset : indexOffset+8])
var n uint64
metaLen, read := binary.Uvarint(s.mem[storedOffset : storedOffset+binary.MaxVarintLen64])
n += uint64(read)
dataLen, read := binary.Uvarint(s.mem[storedOffset+n : storedOffset+n+binary.MaxVarintLen64])
n += uint64(read)
return indexOffset, storedOffset, n, metaLen, dataLen
}

View File

@ -343,8 +343,9 @@ func (s *SegmentBase) DocNumbers(ids []string) (*roaring.Bitmap, error) {
return nil, err return nil, err
} }
var postings *PostingsList
for _, id := range ids { for _, id := range ids {
postings, err := idDict.postingsList([]byte(id), nil) postings, err = idDict.postingsList([]byte(id), nil, postings)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -31,10 +31,9 @@ func (r *RollbackPoint) GetInternal(key []byte) []byte {
return r.meta[string(key)] return r.meta[string(key)]
} }
// RollbackPoints returns an array of rollback points available // RollbackPoints returns an array of rollback points available for
// for the application to make a decision on where to rollback // the application to rollback to, with more recent rollback points
// to. A nil return value indicates that there are no available // (higher epochs) coming first.
// rollback points.
func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) { func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) {
if s.rootBolt == nil { if s.rootBolt == nil {
return nil, fmt.Errorf("RollbackPoints: root is nil") return nil, fmt.Errorf("RollbackPoints: root is nil")
@ -54,7 +53,7 @@ func (s *Scorch) RollbackPoints() ([]*RollbackPoint, error) {
snapshots := tx.Bucket(boltSnapshotsBucket) snapshots := tx.Bucket(boltSnapshotsBucket)
if snapshots == nil { if snapshots == nil {
return nil, fmt.Errorf("RollbackPoints: no snapshots available") return nil, nil
} }
rollbackPoints := []*RollbackPoint{} rollbackPoints := []*RollbackPoint{}
@ -150,10 +149,7 @@ func (s *Scorch) Rollback(to *RollbackPoint) error {
revert.snapshot = indexSnapshot revert.snapshot = indexSnapshot
revert.applied = make(chan error) revert.applied = make(chan error)
if !s.unsafeBatch {
revert.persisted = make(chan error) revert.persisted = make(chan error)
}
return nil return nil
}) })
@ -173,9 +169,5 @@ func (s *Scorch) Rollback(to *RollbackPoint) error {
return fmt.Errorf("Rollback: failed with err: %v", err) return fmt.Errorf("Rollback: failed with err: %v", err)
} }
if revert.persisted != nil { return <-revert.persisted
err = <-revert.persisted
}
return err
} }

View File

@ -837,6 +837,11 @@ func (udc *UpsideDownCouch) Batch(batch *index.Batch) (err error) {
docBackIndexRowErr = err docBackIndexRowErr = err
return return
} }
defer func() {
if cerr := kvreader.Close(); err == nil && cerr != nil {
docBackIndexRowErr = cerr
}
}()
for docID, doc := range batch.IndexOps { for docID, doc := range batch.IndexOps {
backIndexRow, err := backIndexRowForDoc(kvreader, index.IndexInternalID(docID)) backIndexRow, err := backIndexRowForDoc(kvreader, index.IndexInternalID(docID))
@ -847,12 +852,6 @@ func (udc *UpsideDownCouch) Batch(batch *index.Batch) (err error) {
docBackIndexRowCh <- &docBackIndexRow{docID, doc, backIndexRow} docBackIndexRowCh <- &docBackIndexRow{docID, doc, backIndexRow}
} }
err = kvreader.Close()
if err != nil {
docBackIndexRowErr = err
return
}
}() }()
// wait for analysis result // wait for analysis result

View File

@ -15,12 +15,11 @@
package bleve package bleve
import ( import (
"context"
"sort" "sort"
"sync" "sync"
"time" "time"
"golang.org/x/net/context"
"github.com/blevesearch/bleve/document" "github.com/blevesearch/bleve/document"
"github.com/blevesearch/bleve/index" "github.com/blevesearch/bleve/index"
"github.com/blevesearch/bleve/index/store" "github.com/blevesearch/bleve/index/store"

View File

@ -15,6 +15,7 @@
package bleve package bleve
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -22,8 +23,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/net/context"
"github.com/blevesearch/bleve/document" "github.com/blevesearch/bleve/document"
"github.com/blevesearch/bleve/index" "github.com/blevesearch/bleve/index"
"github.com/blevesearch/bleve/index/store" "github.com/blevesearch/bleve/index/store"

View File

@ -15,11 +15,10 @@
package search package search
import ( import (
"context"
"time" "time"
"github.com/blevesearch/bleve/index" "github.com/blevesearch/bleve/index"
"golang.org/x/net/context"
) )
type Collector interface { type Collector interface {

View File

@ -15,11 +15,11 @@
package collector package collector
import ( import (
"context"
"time" "time"
"github.com/blevesearch/bleve/index" "github.com/blevesearch/bleve/index"
"github.com/blevesearch/bleve/search" "github.com/blevesearch/bleve/search"
"golang.org/x/net/context"
) )
type collectorStore interface { type collectorStore interface {

View File

@ -1,10 +1,10 @@
package goth package goth
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"golang.org/x/net/context"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )

View File

@ -4,17 +4,18 @@ package facebook
import ( import (
"bytes" "bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/markbates/goth" "github.com/markbates/goth"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -22,7 +23,7 @@ import (
const ( const (
authURL string = "https://www.facebook.com/dialog/oauth" authURL string = "https://www.facebook.com/dialog/oauth"
tokenURL string = "https://graph.facebook.com/oauth/access_token" tokenURL string = "https://graph.facebook.com/oauth/access_token"
endpointProfile string = "https://graph.facebook.com/me?fields=email,first_name,last_name,link,about,id,name,picture,location" endpointProfile string = "https://graph.facebook.com/me?fields="
) )
// New creates a new Facebook provider, and sets up important connection details. // New creates a new Facebook provider, and sets up important connection details.
@ -68,9 +69,9 @@ func (p *Provider) Debug(debug bool) {}
// BeginAuth asks Facebook for an authentication end-point. // BeginAuth asks Facebook for an authentication end-point.
func (p *Provider) BeginAuth(state string) (goth.Session, error) { func (p *Provider) BeginAuth(state string) (goth.Session, error) {
url := p.config.AuthCodeURL(state) authUrl := p.config.AuthCodeURL(state)
session := &Session{ session := &Session{
AuthURL: url, AuthURL: authUrl,
} }
return session, nil return session, nil
} }
@ -96,7 +97,15 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
hash.Write([]byte(sess.AccessToken)) hash.Write([]byte(sess.AccessToken))
appsecretProof := hex.EncodeToString(hash.Sum(nil)) appsecretProof := hex.EncodeToString(hash.Sum(nil))
response, err := p.Client().Get(endpointProfile + "&access_token=" + url.QueryEscape(sess.AccessToken) + "&appsecret_proof=" + appsecretProof) reqUrl := fmt.Sprint(
endpointProfile,
strings.Join(p.config.Scopes, ","),
"&access_token=",
url.QueryEscape(sess.AccessToken),
"&appsecret_proof=",
appsecretProof,
)
response, err := p.Client().Get(reqUrl)
if err != nil { if err != nil {
return user, err return user, err
} }
@ -168,17 +177,31 @@ func newConfig(provider *Provider, scopes []string) *oauth2.Config {
}, },
Scopes: []string{ Scopes: []string{
"email", "email",
"first_name",
"last_name",
"link",
"about",
"id",
"name",
"picture",
"location",
}, },
} }
defaultScopes := map[string]struct{}{ // creates possibility to invoke field method like 'picture.type(large)'
"email": {}, var found bool
for _, sc := range scopes {
sc := sc
for i, defScope := range c.Scopes {
if defScope == strings.Split(sc, ".")[0] {
c.Scopes[i] = sc
found = true
} }
for _, scope := range scopes {
if _, exists := defaultScopes[scope]; !exists {
c.Scopes = append(c.Scopes, scope)
} }
if !found {
c.Scopes = append(c.Scopes, sc)
}
found = false
} }
return c return c

74
vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go generated vendored Normal file
View File

@ -0,0 +1,74 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.7
// Package ctxhttp provides helper functions for performing context-aware HTTP requests.
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
import (
"io"
"net/http"
"net/url"
"strings"
"golang.org/x/net/context"
)
// Do sends an HTTP request with the provided http.Client and returns
// an HTTP response.
//
// If the client is nil, http.DefaultClient is used.
//
// The provided ctx must be non-nil. If it is canceled or times out,
// ctx.Err() will be returned.
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req.WithContext(ctx))
// If we got an error, and the context has been canceled,
// the context's error is probably more useful.
if err != nil {
select {
case <-ctx.Done():
err = ctx.Err()
default:
}
}
return resp, err
}
// Get issues a GET request via the Do function.
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
return Do(ctx, client, req)
}
// Head issues a HEAD request via the Do function.
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
req, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
return Do(ctx, client, req)
}
// Post issues a POST request via the Do function.
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", bodyType)
return Do(ctx, client, req)
}
// PostForm issues a POST request via the Do function.
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}

View File

@ -0,0 +1,147 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.7
package ctxhttp // import "golang.org/x/net/context/ctxhttp"
import (
"io"
"net/http"
"net/url"
"strings"
"golang.org/x/net/context"
)
func nop() {}
var (
testHookContextDoneBeforeHeaders = nop
testHookDoReturned = nop
testHookDidBodyClose = nop
)
// Do sends an HTTP request with the provided http.Client and returns an HTTP response.
// If the client is nil, http.DefaultClient is used.
// If the context is canceled or times out, ctx.Err() will be returned.
func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
if client == nil {
client = http.DefaultClient
}
// TODO(djd): Respect any existing value of req.Cancel.
cancel := make(chan struct{})
req.Cancel = cancel
type responseAndError struct {
resp *http.Response
err error
}
result := make(chan responseAndError, 1)
// Make local copies of test hooks closed over by goroutines below.
// Prevents data races in tests.
testHookDoReturned := testHookDoReturned
testHookDidBodyClose := testHookDidBodyClose
go func() {
resp, err := client.Do(req)
testHookDoReturned()
result <- responseAndError{resp, err}
}()
var resp *http.Response
select {
case <-ctx.Done():
testHookContextDoneBeforeHeaders()
close(cancel)
// Clean up after the goroutine calling client.Do:
go func() {
if r := <-result; r.resp != nil {
testHookDidBodyClose()
r.resp.Body.Close()
}
}()
return nil, ctx.Err()
case r := <-result:
var err error
resp, err = r.resp, r.err
if err != nil {
return resp, err
}
}
c := make(chan struct{})
go func() {
select {
case <-ctx.Done():
close(cancel)
case <-c:
// The response's Body is closed.
}
}()
resp.Body = &notifyingReader{resp.Body, c}
return resp, nil
}
// Get issues a GET request via the Do function.
func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
return Do(ctx, client, req)
}
// Head issues a HEAD request via the Do function.
func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) {
req, err := http.NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
return Do(ctx, client, req)
}
// Post issues a POST request via the Do function.
func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", bodyType)
return Do(ctx, client, req)
}
// PostForm issues a POST request via the Do function.
func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) {
return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
// notifyingReader is an io.ReadCloser that closes the notify channel after
// Close is called or a Read fails on the underlying ReadCloser.
type notifyingReader struct {
io.ReadCloser
notify chan<- struct{}
}
func (r *notifyingReader) Read(p []byte) (int, error) {
n, err := r.ReadCloser.Read(p)
if err != nil && r.notify != nil {
close(r.notify)
r.notify = nil
}
return n, err
}
func (r *notifyingReader) Close() error {
err := r.ReadCloser.Close()
if r.notify != nil {
close(r.notify)
r.notify = nil
}
return err
}

2
vendor/golang.org/x/oauth2/LICENSE generated vendored
View File

@ -1,4 +1,4 @@
Copyright (c) 2009 The oauth2 Authors. All rights reserved. Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are modification, are permitted provided that the following conditions are

View File

@ -1,25 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine
// App Engine hooks.
package oauth2
import (
"net/http"
"golang.org/x/net/context"
"golang.org/x/oauth2/internal"
"google.golang.org/appengine/urlfetch"
)
func init() {
internal.RegisterContextClientFunc(contextClientAppEngine)
}
func contextClientAppEngine(ctx context.Context) (*http.Client, error) {
return urlfetch.Client(ctx), nil
}

View File

@ -0,0 +1,13 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine
package internal
import "google.golang.org/appengine/urlfetch"
func init() {
appengineClientHook = urlfetch.Client
}

6
vendor/golang.org/x/oauth2/internal/doc.go generated vendored Normal file
View File

@ -0,0 +1,6 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package internal contains support packages for oauth2 package.
package internal

View File

@ -2,18 +2,14 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package internal contains support packages for oauth2 package.
package internal package internal
import ( import (
"bufio"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"io"
"strings"
) )
// ParseKey converts the binary contents of a private key file // ParseKey converts the binary contents of a private key file
@ -30,7 +26,7 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
if err != nil { if err != nil {
parsedKey, err = x509.ParsePKCS1PrivateKey(key) parsedKey, err = x509.ParsePKCS1PrivateKey(key)
if err != nil { if err != nil {
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err) return nil, fmt.Errorf("private key should be a PEM or plain PKCS1 or PKCS8; parse error: %v", err)
} }
} }
parsed, ok := parsedKey.(*rsa.PrivateKey) parsed, ok := parsedKey.(*rsa.PrivateKey)
@ -39,38 +35,3 @@ func ParseKey(key []byte) (*rsa.PrivateKey, error) {
} }
return parsed, nil return parsed, nil
} }
func ParseINI(ini io.Reader) (map[string]map[string]string, error) {
result := map[string]map[string]string{
"": map[string]string{}, // root section
}
scanner := bufio.NewScanner(ini)
currentSection := ""
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, ";") {
// comment.
continue
}
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
currentSection = strings.TrimSpace(line[1 : len(line)-1])
result[currentSection] = map[string]string{}
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) == 2 && parts[0] != "" {
result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error scanning ini: %v", err)
}
return result, nil
}
func CondVal(v string) []string {
if v == "" {
return nil
}
return []string{v}
}

View File

@ -2,11 +2,12 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package internal contains support packages for oauth2 package.
package internal package internal
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -17,10 +18,10 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/net/context" "golang.org/x/net/context/ctxhttp"
) )
// Token represents the crendentials used to authorize // Token represents the credentials used to authorize
// the requests to access protected resources on the OAuth 2.0 // the requests to access protected resources on the OAuth 2.0
// provider's backend. // provider's backend.
// //
@ -91,6 +92,7 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
var brokenAuthHeaderProviders = []string{ var brokenAuthHeaderProviders = []string{
"https://accounts.google.com/", "https://accounts.google.com/",
"https://api.codeswholesale.com/oauth/token",
"https://api.dropbox.com/", "https://api.dropbox.com/",
"https://api.dropboxapi.com/", "https://api.dropboxapi.com/",
"https://api.instagram.com/", "https://api.instagram.com/",
@ -99,10 +101,16 @@ var brokenAuthHeaderProviders = []string{
"https://api.pushbullet.com/", "https://api.pushbullet.com/",
"https://api.soundcloud.com/", "https://api.soundcloud.com/",
"https://api.twitch.tv/", "https://api.twitch.tv/",
"https://id.twitch.tv/",
"https://app.box.com/", "https://app.box.com/",
"https://api.box.com/",
"https://connect.stripe.com/", "https://connect.stripe.com/",
"https://login.mailchimp.com/",
"https://login.microsoftonline.com/", "https://login.microsoftonline.com/",
"https://login.salesforce.com/", "https://login.salesforce.com/",
"https://login.windows.net",
"https://login.live.com/",
"https://login.live-int.com/",
"https://oauth.sandbox.trainingpeaks.com/", "https://oauth.sandbox.trainingpeaks.com/",
"https://oauth.trainingpeaks.com/", "https://oauth.trainingpeaks.com/",
"https://oauth.vk.com/", "https://oauth.vk.com/",
@ -117,6 +125,24 @@ var brokenAuthHeaderProviders = []string{
"https://www.strava.com/oauth/", "https://www.strava.com/oauth/",
"https://www.wunderlist.com/oauth/", "https://www.wunderlist.com/oauth/",
"https://api.patreon.com/", "https://api.patreon.com/",
"https://sandbox.codeswholesale.com/oauth/token",
"https://api.sipgate.com/v1/authorization/oauth",
"https://api.medium.com/v1/tokens",
"https://log.finalsurge.com/oauth/token",
"https://multisport.todaysplan.com.au/rest/oauth/access_token",
"https://whats.todaysplan.com.au/rest/oauth/access_token",
"https://stackoverflow.com/oauth/access_token",
"https://account.health.nokia.com",
"https://accounts.zoho.com",
}
// brokenAuthHeaderDomains lists broken providers that issue dynamic endpoints.
var brokenAuthHeaderDomains = []string{
".auth0.com",
".force.com",
".myshopify.com",
".okta.com",
".oktapreview.com",
} }
func RegisterBrokenAuthHeaderProvider(tokenURL string) { func RegisterBrokenAuthHeaderProvider(tokenURL string) {
@ -139,6 +165,14 @@ func providerAuthHeaderWorks(tokenURL string) bool {
} }
} }
if u, err := url.Parse(tokenURL); err == nil {
for _, s := range brokenAuthHeaderDomains {
if strings.HasSuffix(u.Host, s) {
return false
}
}
}
// Assume the provider implements the spec properly // Assume the provider implements the spec properly
// otherwise. We can add more exceptions as they're // otherwise. We can add more exceptions as they're
// discovered. We will _not_ be adding configurable hooks // discovered. We will _not_ be adding configurable hooks
@ -147,24 +181,24 @@ func providerAuthHeaderWorks(tokenURL string) bool {
} }
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) { func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
hc, err := ContextClient(ctx)
if err != nil {
return nil, err
}
v.Set("client_id", clientID)
bustedAuth := !providerAuthHeaderWorks(tokenURL) bustedAuth := !providerAuthHeaderWorks(tokenURL)
if bustedAuth && clientSecret != "" { if bustedAuth {
if clientID != "" {
v.Set("client_id", clientID)
}
if clientSecret != "" {
v.Set("client_secret", clientSecret) v.Set("client_secret", clientSecret)
} }
}
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode())) req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if !bustedAuth { if !bustedAuth {
req.SetBasicAuth(clientID, clientSecret) req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
} }
r, err := hc.Do(req) r, err := ctxhttp.Do(ctx, ContextClient(ctx), req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -174,7 +208,10 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
if code := r.StatusCode; code < 200 || code > 299 { if code := r.StatusCode; code < 200 || code > 299 {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body) return nil, &RetrieveError{
Response: r,
Body: body,
}
} }
var token *Token var token *Token
@ -221,5 +258,17 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
if token.RefreshToken == "" { if token.RefreshToken == "" {
token.RefreshToken = v.Get("refresh_token") token.RefreshToken = v.Get("refresh_token")
} }
if token.AccessToken == "" {
return token, errors.New("oauth2: server response missing access_token")
}
return token, nil return token, nil
} }
type RetrieveError struct {
Response *http.Response
Body []byte
}
func (r *RetrieveError) Error() string {
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
}

View File

@ -2,13 +2,11 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package internal contains support packages for oauth2 package.
package internal package internal
import ( import (
"context"
"net/http" "net/http"
"golang.org/x/net/context"
) )
// HTTPClient is the context key to use with golang.org/x/net/context's // HTTPClient is the context key to use with golang.org/x/net/context's
@ -20,50 +18,16 @@ var HTTPClient ContextKey
// because nobody else can create a ContextKey, being unexported. // because nobody else can create a ContextKey, being unexported.
type ContextKey struct{} type ContextKey struct{}
// ContextClientFunc is a func which tries to return an *http.Client var appengineClientHook func(context.Context) *http.Client
// given a Context value. If it returns an error, the search stops
// with that error. If it returns (nil, nil), the search continues
// down the list of registered funcs.
type ContextClientFunc func(context.Context) (*http.Client, error)
var contextClientFuncs []ContextClientFunc func ContextClient(ctx context.Context) *http.Client {
func RegisterContextClientFunc(fn ContextClientFunc) {
contextClientFuncs = append(contextClientFuncs, fn)
}
func ContextClient(ctx context.Context) (*http.Client, error) {
if ctx != nil { if ctx != nil {
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok { if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
return hc, nil return hc
} }
} }
for _, fn := range contextClientFuncs { if appengineClientHook != nil {
c, err := fn(ctx) return appengineClientHook(ctx)
if err != nil {
return nil, err
} }
if c != nil { return http.DefaultClient
return c, nil
}
}
return http.DefaultClient, nil
}
func ContextTransport(ctx context.Context) http.RoundTripper {
hc, err := ContextClient(ctx)
// This is a rare error case (somebody using nil on App Engine).
if err != nil {
return ErrorTransport{err}
}
return hc.Transport
}
// ErrorTransport returns the specified error on RoundTrip.
// This RoundTripper should be used in rare error cases where
// error handling can be postponed to response handling time.
type ErrorTransport struct{ Err error }
func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) {
return nil, t.Err
} }

61
vendor/golang.org/x/oauth2/oauth2.go generated vendored
View File

@ -3,19 +3,20 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package oauth2 provides support for making // Package oauth2 provides support for making
// OAuth2 authorized and authenticated HTTP requests. // OAuth2 authorized and authenticated HTTP requests,
// as specified in RFC 6749.
// It can additionally grant authorization with Bearer JWT. // It can additionally grant authorization with Bearer JWT.
package oauth2 // import "golang.org/x/oauth2" package oauth2 // import "golang.org/x/oauth2"
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"golang.org/x/net/context"
"golang.org/x/oauth2/internal" "golang.org/x/oauth2/internal"
) )
@ -117,21 +118,30 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
// that asks for permissions for the required scopes explicitly. // that asks for permissions for the required scopes explicitly.
// //
// State is a token to protect the user from CSRF attacks. You must // State is a token to protect the user from CSRF attacks. You must
// always provide a non-zero string and validate that it matches the // always provide a non-empty string and validate that it matches the
// the state query parameter on your redirect callback. // the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
// //
// Opts may include AccessTypeOnline or AccessTypeOffline, as well // Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce. // as ApprovalForce.
// It can also be used to pass the PKCE challange.
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string { func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString(c.Endpoint.AuthURL) buf.WriteString(c.Endpoint.AuthURL)
v := url.Values{ v := url.Values{
"response_type": {"code"}, "response_type": {"code"},
"client_id": {c.ClientID}, "client_id": {c.ClientID},
"redirect_uri": internal.CondVal(c.RedirectURL), }
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), if c.RedirectURL != "" {
"state": internal.CondVal(state), v.Set("redirect_uri", c.RedirectURL)
}
if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
if state != "" {
// TODO(light): Docs say never to omit state; don't allow empty.
v.Set("state", state)
} }
for _, opt := range opts { for _, opt := range opts {
opt.setValue(v) opt.setValue(v)
@ -157,12 +167,15 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
// The HTTP client to use is derived from the context. // The HTTP client to use is derived from the context.
// If nil, http.DefaultClient is used. // If nil, http.DefaultClient is used.
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) { func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
return retrieveToken(ctx, c, url.Values{ v := url.Values{
"grant_type": {"password"}, "grant_type": {"password"},
"username": {username}, "username": {username},
"password": {password}, "password": {password},
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), }
}) if len(c.Scopes) > 0 {
v.Set("scope", strings.Join(c.Scopes, " "))
}
return retrieveToken(ctx, c, v)
} }
// Exchange converts an authorization code into a token. // Exchange converts an authorization code into a token.
@ -175,13 +188,21 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
// //
// The code will be in the *http.Request.FormValue("code"). Before // The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state"). // calling Exchange, be sure to validate FormValue("state").
func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) { //
return retrieveToken(ctx, c, url.Values{ // Opts may include the PKCE verifier code if previously used in AuthCodeURL.
// See https://www.oauth.com/oauth2-servers/pkce/ for more info.
func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOption) (*Token, error) {
v := url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
"code": {code}, "code": {code},
"redirect_uri": internal.CondVal(c.RedirectURL), }
"scope": internal.CondVal(strings.Join(c.Scopes, " ")), if c.RedirectURL != "" {
}) v.Set("redirect_uri", c.RedirectURL)
}
for _, opt := range opts {
opt.setValue(v)
}
return retrieveToken(ctx, c, v)
} }
// Client returns an HTTP client using the provided token. // Client returns an HTTP client using the provided token.
@ -292,20 +313,20 @@ var HTTPClient internal.ContextKey
// NewClient creates an *http.Client from a Context and TokenSource. // NewClient creates an *http.Client from a Context and TokenSource.
// The returned client is not valid beyond the lifetime of the context. // The returned client is not valid beyond the lifetime of the context.
// //
// Note that if a custom *http.Client is provided via the Context it
// is used only for token acquisition and is not used to configure the
// *http.Client returned from NewClient.
//
// As a special case, if src is nil, a non-OAuth2 client is returned // As a special case, if src is nil, a non-OAuth2 client is returned
// using the provided context. This exists to support related OAuth2 // using the provided context. This exists to support related OAuth2
// packages. // packages.
func NewClient(ctx context.Context, src TokenSource) *http.Client { func NewClient(ctx context.Context, src TokenSource) *http.Client {
if src == nil { if src == nil {
c, err := internal.ContextClient(ctx) return internal.ContextClient(ctx)
if err != nil {
return &http.Client{Transport: internal.ErrorTransport{Err: err}}
}
return c
} }
return &http.Client{ return &http.Client{
Transport: &Transport{ Transport: &Transport{
Base: internal.ContextTransport(ctx), Base: internal.ContextClient(ctx).Transport,
Source: ReuseTokenSource(nil, src), Source: ReuseTokenSource(nil, src),
}, },
} }

23
vendor/golang.org/x/oauth2/token.go generated vendored
View File

@ -5,13 +5,14 @@
package oauth2 package oauth2
import ( import (
"context"
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"golang.org/x/net/context"
"golang.org/x/oauth2/internal" "golang.org/x/oauth2/internal"
) )
@ -20,7 +21,7 @@ import (
// expirations due to client-server time mismatches. // expirations due to client-server time mismatches.
const expiryDelta = 10 * time.Second const expiryDelta = 10 * time.Second
// Token represents the crendentials used to authorize // Token represents the credentials used to authorize
// the requests to access protected resources on the OAuth 2.0 // the requests to access protected resources on the OAuth 2.0
// provider's backend. // provider's backend.
// //
@ -123,7 +124,7 @@ func (t *Token) expired() bool {
if t.Expiry.IsZero() { if t.Expiry.IsZero() {
return false return false
} }
return t.Expiry.Add(-expiryDelta).Before(time.Now()) return t.Expiry.Round(0).Add(-expiryDelta).Before(time.Now())
} }
// Valid reports whether t is non-nil, has an AccessToken, and is not expired. // Valid reports whether t is non-nil, has an AccessToken, and is not expired.
@ -152,7 +153,23 @@ func tokenFromInternal(t *internal.Token) *Token {
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v) tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v)
if err != nil { if err != nil {
if rErr, ok := err.(*internal.RetrieveError); ok {
return nil, (*RetrieveError)(rErr)
}
return nil, err return nil, err
} }
return tokenFromInternal(tk), nil return tokenFromInternal(tk), nil
} }
// RetrieveError is the error returned when the token endpoint returns a
// non-2XX HTTP status code.
type RetrieveError struct {
Response *http.Response
// Body is the body that was consumed by reading Response.Body.
// It may be truncated.
Body []byte
}
func (r *RetrieveError) Error() string {
return fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", r.Response.Status, r.Body)
}

View File

@ -31,9 +31,17 @@ type Transport struct {
} }
// RoundTrip authorizes and authenticates the request with an // RoundTrip authorizes and authenticates the request with an
// access token. If no token exists or token is expired, // access token from Transport's Source.
// tries to refresh/fetch a new token.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
reqBodyClosed := false
if req.Body != nil {
defer func() {
if !reqBodyClosed {
req.Body.Close()
}
}()
}
if t.Source == nil { if t.Source == nil {
return nil, errors.New("oauth2: Transport's Source is nil") return nil, errors.New("oauth2: Transport's Source is nil")
} }
@ -46,6 +54,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
token.SetAuthHeader(req2) token.SetAuthHeader(req2)
t.setModReq(req, req2) t.setModReq(req, req2)
res, err := t.base().RoundTrip(req2) res, err := t.base().RoundTrip(req2)
// req.Body is assumed to have been closed by the base RoundTripper.
reqBodyClosed = true
if err != nil { if err != nil {
t.setModReq(req, nil) t.setModReq(req, nil)
return nil, err return nil, err