Newer
Older
pokemon-go-trade / vendor / github.com / go-pg / pg / orm / join.go
package orm

import (
	"reflect"

	"github.com/go-pg/pg/internal"
	"github.com/go-pg/pg/types"
)

type join struct {
	Parent    *join
	BaseModel TableModel
	JoinModel TableModel
	Rel       *Relation

	ApplyQuery func(*Query) (*Query, error)
	Columns    []string
	on         []*condAppender
}

func (j *join) AppendOn(app *condAppender) {
	j.on = append(j.on, app)
}

func (j *join) Select(q *Query) error {
	switch j.Rel.Type {
	case HasManyRelation:
		return j.selectMany(q)
	case Many2ManyRelation:
		return j.selectM2M(q)
	}
	panic("not reached")
}

func (j *join) selectMany(q *Query) error {
	q, err := j.manyQuery(q)
	if err != nil {
		return err
	}
	if q == nil {
		return nil
	}
	return q.Select()
}

func (j *join) manyQuery(q *Query) (*Query, error) {
	manyModel := newManyModel(j)
	if manyModel == nil {
		return nil, nil
	}

	q = q.Model(manyModel)
	if j.ApplyQuery != nil {
		var err error
		q, err = j.ApplyQuery(q)
		if err != nil {
			return nil, err
		}
	}

	if len(q.columns) == 0 {
		q.columns = append(q.columns, hasManyColumnsAppender{j})
	}

	baseTable := j.BaseModel.Table()
	var where []byte
	if len(j.Rel.FKs) > 1 {
		where = append(where, '(')
	}
	where = appendColumns(where, j.JoinModel.Table().Alias, j.Rel.FKs)
	if len(j.Rel.FKs) > 1 {
		where = append(where, ')')
	}
	where = append(where, " IN ("...)
	where = appendChildValues(
		where, j.JoinModel.Root(), j.JoinModel.ParentIndex(), j.Rel.FKValues)
	where = append(where, ")"...)
	q = q.Where(internal.BytesToString(where))

	if j.Rel.Polymorphic != nil {
		q = q.Where(`? IN (?, ?)`,
			j.Rel.Polymorphic.Column,
			baseTable.ModelName, baseTable.TypeName)
	}

	return q, nil
}

func (j *join) selectM2M(q *Query) error {
	q, err := j.m2mQuery(q)
	if err != nil {
		return err
	}
	if q == nil {
		return nil
	}
	return q.Select()
}

func (j *join) m2mQuery(q *Query) (*Query, error) {
	m2mModel := newM2MModel(j)
	if m2mModel == nil {
		return nil, nil
	}

	q = q.Model(m2mModel)
	if j.ApplyQuery != nil {
		var err error
		q, err = j.ApplyQuery(q)
		if err != nil {
			return nil, err
		}
	}

	if len(q.columns) == 0 {
		q.columns = append(q.columns, hasManyColumnsAppender{j})
	}

	index := j.JoinModel.ParentIndex()
	baseTable := j.BaseModel.Table()
	var join []byte
	join = append(join, "JOIN "...)
	join = q.FormatQuery(join, string(j.Rel.M2MTableName))
	join = append(join, " AS "...)
	join = append(join, j.Rel.M2MTableAlias...)
	join = append(join, " ON ("...)
	for i, col := range j.Rel.BaseFKs {
		if i > 0 {
			join = append(join, ", "...)
		}
		join = append(join, j.Rel.M2MTableAlias...)
		join = append(join, '.')
		join = types.AppendField(join, col, 1)
	}
	join = append(join, ") IN ("...)
	join = appendChildValues(join, j.BaseModel.Root(), index, baseTable.PKs)
	join = append(join, ")"...)
	q = q.Join(internal.BytesToString(join))

	joinTable := j.JoinModel.Table()
	for i, col := range j.Rel.JoinFKs {
		if i >= len(joinTable.PKs) {
			break
		}
		pk := joinTable.PKs[i]
		q = q.Where("?.? = ?.?",
			joinTable.Alias, pk.Column,
			j.Rel.M2MTableAlias, types.F(col))
	}

	return q, nil
}

func (j *join) hasParent() bool {
	if j.Parent != nil {
		switch j.Parent.Rel.Type {
		case HasOneRelation, BelongsToRelation:
			return true
		}
	}
	return false
}

func (j *join) appendAlias(b []byte) []byte {
	b = append(b, '"')
	b = appendAlias(b, j, true)
	b = append(b, '"')
	return b
}

func (j *join) appendAliasColumn(b []byte, column string) []byte {
	b = append(b, '"')
	b = appendAlias(b, j, true)
	b = append(b, "__"...)
	b = types.AppendField(b, column, 0)
	b = append(b, '"')
	return b
}

func (j *join) appendBaseAlias(b []byte) []byte {
	if j.hasParent() {
		b = append(b, '"')
		b = appendAlias(b, j.Parent, true)
		b = append(b, '"')
		return b
	}
	return append(b, j.BaseModel.Table().Alias...)
}

func appendAlias(b []byte, j *join, topLevel bool) []byte {
	if j.hasParent() {
		b = appendAlias(b, j.Parent, topLevel)
		topLevel = false
	}
	if !topLevel {
		b = append(b, "__"...)
	}
	b = append(b, j.Rel.Field.SQLName...)
	return b
}

func (j *join) appendHasOneColumns(b []byte) []byte {
	if j.Columns == nil {
		for i, f := range j.JoinModel.Table().Fields {
			if i > 0 {
				b = append(b, ", "...)
			}
			b = j.appendAlias(b)
			b = append(b, '.')
			b = append(b, f.Column...)
			b = append(b, " AS "...)
			b = j.appendAliasColumn(b, f.SQLName)
		}
		return b
	}

	for i, column := range j.Columns {
		if i > 0 {
			b = append(b, ", "...)
		}
		b = j.appendAlias(b)
		b = append(b, '.')
		b = types.AppendField(b, column, 1)
		b = append(b, " AS "...)
		b = j.appendAliasColumn(b, column)
	}

	return b
}

func (j *join) appendHasOneJoin(q *Query, b []byte) []byte {
	isSoftDelete := q.isSoftDelete()

	b = append(b, "LEFT JOIN "...)
	b = q.FormatQuery(b, string(j.JoinModel.Table().FullNameForSelects))
	b = append(b, " AS "...)
	b = j.appendAlias(b)

	b = append(b, " ON "...)

	if isSoftDelete {
		b = append(b, '(')
	}

	if len(j.Rel.FKs) > 1 {
		b = append(b, '(')
	}
	if j.Rel.Type == HasOneRelation {
		joinTable := j.Rel.JoinTable
		for i, fk := range j.Rel.FKs {
			if i > 0 {
				b = append(b, " AND "...)
			}
			b = j.appendAlias(b)
			b = append(b, '.')
			b = append(b, joinTable.PKs[i].Column...)
			b = append(b, " = "...)
			b = j.appendBaseAlias(b)
			b = append(b, '.')
			b = append(b, fk.Column...)
		}
	} else {
		baseTable := j.BaseModel.Table()
		for i, fk := range j.Rel.FKs {
			if i > 0 {
				b = append(b, " AND "...)
			}
			b = j.appendAlias(b)
			b = append(b, '.')
			b = append(b, fk.Column...)
			b = append(b, " = "...)
			b = j.appendBaseAlias(b)
			b = append(b, '.')
			b = append(b, baseTable.PKs[i].Column...)
		}
	}
	if len(j.Rel.FKs) > 1 {
		b = append(b, ')')
	}

	for _, on := range j.on {
		b = on.AppendSep(b)
		b = on.AppendFormat(b, q)
	}

	if isSoftDelete {
		b = append(b, ')')
	}

	if isSoftDelete {
		b = append(b, " AND "...)
		b = j.appendBaseAlias(b)
		b = q.appendSoftDelete(b)
	}

	return b
}

type hasManyColumnsAppender struct {
	*join
}

func (q hasManyColumnsAppender) AppendFormat(b []byte, fmter QueryFormatter) []byte {
	if q.Rel.M2MTableAlias != "" {
		b = append(b, q.Rel.M2MTableAlias...)
		b = append(b, ".*, "...)
	}

	joinTable := q.JoinModel.Table()

	if q.Columns != nil {
		for i, column := range q.Columns {
			if i > 0 {
				b = append(b, ", "...)
			}
			b = append(b, joinTable.Alias...)
			b = append(b, '.')
			b = types.AppendField(b, column, 1)
		}
		return b
	}

	return appendColumns(b, joinTable.Alias, joinTable.Fields)
}

func appendChildValues(b []byte, v reflect.Value, index []int, fields []*Field) []byte {
	seen := make(map[string]struct{})
	walk(v, index, func(v reflect.Value) {
		start := len(b)

		if len(fields) > 1 {
			b = append(b, '(')
		}
		for i, f := range fields {
			if i > 0 {
				b = append(b, ", "...)
			}
			b = f.AppendValue(b, v, 1)
		}
		if len(fields) > 1 {
			b = append(b, ')')
		}
		b = append(b, ", "...)

		if _, ok := seen[string(b[start:])]; ok {
			b = b[:start]
		} else {
			seen[string(b[start:])] = struct{}{}
		}
	})
	if len(seen) > 0 {
		b = b[:len(b)-2] // trim ", "
	}
	return b
}