常见问题解答(FAQ)
问题
如果由一个结构体 T 创建一个实体?
如何创建结构体(或突变)级别验证器?
如何编写一个日志审计扩展?
如何编写自定义谓词?
如果将自定义谓词添加到代码生成资产中?
如何在 PostgreSQL 中定义网络地址字段?
如何在 MySQL 中将时间字段自定义为 DATETIME?
如何为 ID 使用自定义生成器?
如何使用 XID 自定义全球唯一 ID?
如何在 MySQL 中定义空间数据类型字段?
如何扩展已生成的模型?
如果扩展已生成的构建器?
如何在 BLOB 列中存储 Protobuf 对象?
如何为表添加 CHECK 约束?
如何定义自定义精度的数字字段?
如何配置两个及以上 DB 实现读写分离?
如何配置 json.Marshal 以内联顶级对象中的 edges 键?
解答
如果由一个结构体 T 创建一个实体?
不同的构建器不支持通过给定的结构体 T 设置实体字段(或边)的选项。
原因是当更新数据库时(例如 &ent.T{Age: 0, Name: ""})无法区分零值与实际值。
设置这些值可能导致在数据库中设置错误的值或更新非必要列。
然而 外部模板 选项可以让你通过添加自定义逻辑扩展代码生成资产。 例如,若要每个创建的构建器生成方法,接受结构体作为输入并配置构建器,可以使用如下模板:
{{ range $n := $.Nodes }}
{{ $builder := $n.CreateName }}
{{ $receiver := $n.CreateReceiver }}
func ({{ $receiver }} *{{ $builder }}) Set{{ $n.Name }}(input *{{ $n.Name }}) *{{ $builder }} {
{{- range $f := $n.Fields }}
{{- $setter := print "Set" $f.StructField }}
{{ $receiver }}.{{ $setter }}(input.{{ $f.StructField }})
{{- end }}
return {{ $receiver }}
}
{{ end }}
如何创建突变级别验证器?
为实现突变级别验证器,你可以使用 模式钩子 来验证应用到一个实体类型的变更, 也可以使用 事务钩子 来校验应用到多个实体类型的突变(例如 GraphQL 突变)。 例如:
// A VersionHook is a dummy example for a hook that validates the "version" field
// is incremented by 1 on each update. Note that this is just a dummy example, and
// it doesn't promise consistency in the database.
func VersionHook() ent.Hook {
type OldSetVersion interface {
SetVersion(int)
Version() (int, bool)
OldVersion(context.Context) (int, error)
}
return func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
ver, ok := m.(OldSetVersion)
if !ok {
return next.Mutate(ctx, m)
}
oldV, err := ver.OldVersion(ctx)
if err != nil {
return nil, err
}
curV, exists := ver.Version()
if !exists {
return nil, fmt.Errorf("version field is required in update mutation")
}
if curV != oldV+1 {
return nil, fmt.Errorf("version field must be incremented by 1")
}
// Add an SQL predicate that validates the "version" column is equal
// to "oldV" (ensure it wasn't changed during the mutation by others).
return next.Mutate(ctx, m)
})
}
}
如何编写一个日志审计扩展?
编写此类扩展的首选方式是使用 ent.Mixin。
通过 Fields 选项设置所有通过导入混合模式的而在模式间共享的字段,并通过 Hooks 选项为这些应用于模式上的所有突变附加突变钩子。
以下示例基于 代码仓库问题跟踪 中的讨论:
// AuditMixin implements the ent.Mixin for sharing
// audit-log capabilities with package schemas.
type AuditMixin struct{
mixin.Schema
}
// Fields of the AuditMixin.
func (AuditMixin) Fields() []ent.Field {
return []ent.Field{
field.Time("created_at").
Immutable().
Default(time.Now),
field.Int("created_by").
Optional(),
field.Time("updated_at").
Default(time.Now).
UpdateDefault(time.Now),
field.Int("updated_by").
Optional(),
}
}
// Hooks of the AuditMixin.
func (AuditMixin) Hooks() []ent.Hook {
return []ent.Hook{
hooks.AuditHook,
}
}
// A AuditHook is an example for audit-log hook.
func AuditHook(next ent.Mutator) ent.Mutator {
// AuditLogger wraps the methods that are shared between all mutations of
// schemas that embed the AuditLog mixin. The variable "exists" is true, if
// the field already exists in the mutation (e.g. was set by a different hook).
type AuditLogger interface {
SetCreatedAt(time.Time)
CreatedAt() (value time.Time, exists bool)
SetCreatedBy(int)
CreatedBy() (id int, exists bool)
SetUpdatedAt(time.Time)
UpdatedAt() (value time.Time, exists bool)
SetUpdatedBy(int)
UpdatedBy() (id int, exists bool)
}
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
ml, ok := m.(AuditLogger)
if !ok {
return nil, fmt.Errorf("unexpected audit-log call from mutation type %T", m)
}
usr, err := viewer.UserFromContext(ctx)
if err != nil {
return nil, err
}
switch op := m.Op(); {
case op.Is(ent.OpCreate):
ml.SetCreatedAt(time.Now())
if _, exists := ml.CreatedBy(); !exists {
ml.SetCreatedBy(usr.ID)
}
case op.Is(ent.OpUpdateOne | ent.OpUpdate):
ml.SetUpdatedAt(time.Now())
if _, exists := ml.UpdatedBy(); !exists {
ml.SetUpdatedBy(usr.ID)
}
}
return next.Mutate(ctx, m)
})
}
如何编写自定义谓词?
用户可以在查询执行前提供自定义谓词应用于查询。例如:
pets := client.Pet.
Query().
Where(predicate.Pet(func(s *sql.Selector) {
s.Where(sql.InInts(pet.OwnerColumn, 1, 2, 3))
})).
AllX(ctx)
users := client.User.
Query().
Where(predicate.User(func(s *sql.Selector) {
s.Where(sqljson.ValueContains(user.FieldTags, "tag"))
})).
AllX(ctx)
更多示例,请参见 谓词(Predicates) 页,或在代码仓库的问题跟踪中获取更进一步的示例, 例如 issue-842。
如果将自定义谓词添加到代码生成资产中?
模板 选项能够使扩展具备覆盖默认代码生成资产的能力。 为在 上一个示例中 生成类型安全的谓词,可像如下使用模板选项:
{{/* A template that adds the "<F>Glob" predicate for all string fields. */}}
{{ define "where/additional/strings" }}
{{ range $f := $.Fields }}
{{ if $f.IsString }}
{{ $func := print $f.StructField "Glob" }}
// {{ $func }} applies the Glob predicate on the {{ quote $f.Name }} field.
func {{ $func }}(pattern string) predicate.{{ $.Name }} {
return predicate.{{ $.Name }}(func(s *sql.Selector) {
s.Where(sql.P(func(b *sql.Builder) {
b.Ident(s.C({{ $f.Constant }})).WriteString(" glob" ).Arg(pattern)
}))
})
}
{{ end }}
{{ end }}
{{ end }}
如何在 PostgreSQL 中定义网络地址字段?
Go 语言类型(GoType) 和 数据库类型(Database Type,原文为 SchemaType)
允许用户定义数据库特定的字段。例如定义 macaddr 字段,
可使用如下配置:
func (T) Fields() []ent.Field {
return []ent.Field{
field.String("mac").
GoType(&MAC{}).
SchemaType(map[string]string{
dialect.Postgres: "macaddr",
}).
Validate(func(s string) error {
_, err := net.ParseMAC(s)
return err
}),
}
}
// MAC represents a physical hardware address.
type MAC struct {
net.HardwareAddr
}
// Scan implements the Scanner interface.
func (m *MAC) Scan(value any) (err error) {
switch v := value.(type) {
case nil:
case []byte:
m.HardwareAddr, err = net.ParseMAC(string(v))
case string:
m.HardwareAddr, err = net.ParseMAC(v)
default:
err = fmt.Errorf("unexpected type %T", v)
}
return
}
// Value implements the driver Valuer interface.
func (m MAC) Value() (driver.Value, error) {
return m.HardwareAddr.String(), nil
}
请注意,如果数据库不支持 macaddr 类型(例如测试环境中的 SQLite),该字段将回退到其原生类型(如 string)。
inet example:
func (T) Fields() []ent.Field {
return []ent.Field{
field.String("ip").
GoType(&Inet{}).
SchemaType(map[string]string{
dialect.Postgres: "inet",
}).
Validate(func(s string) error {
if net.ParseIP(s) == nil {
return fmt.Errorf("invalid value for ip %q", s)
}
return nil
}),
}
}
// Inet represents a single IP address
type Inet struct {
net.IP
}
// Scan implements the Scanner interface
func (i *Inet) Scan(value any) (err error) {
switch v := value.(type) {
case nil:
case []byte:
if i.IP = net.ParseIP(string(v)); i.IP == nil {
err = fmt.Errorf("invalid value for ip %q", v)
}
case string:
if i.IP = net.ParseIP(v); i.IP == nil {
err = fmt.Errorf("invalid value for ip %q", v)
}
default:
err = fmt.Errorf("unexpected type %T", v)
}
return
}
// Value implements the driver Valuer interface
func (i Inet) Value() (driver.Value, error) {
return i.IP.String(), nil
}
如何在 MySQL 中将时间字段自定义为 DATETIME?
Time 字段默认在模式创建中使用 MySQL 的 TIMESTAMP 类型,范围是 '1970-01-01 00:00:01' UTC 到 '2038-01-19 03:14:07' UTC(参见 MySQL 文档)。
若自定义时间字段为更广的范围,可像如下使用 MySQL 的 DATETIME:
field.Time("birth_date").
Optional().
SchemaType(map[string]string{
dialect.MySQL: "datetime",
}),
如何为 ID 使用自定义生成器?
如果你在数据库中使用自定义 ID(例如 Snowflake)生成器而非自增 ID, 你需要编写一个自定义 ID 字段,该字段在资源创建时自动调用生成器。
为达到此目的,你可以使用 DefaultFunc 或模式钩子 —— 这取决于你的使用场景。
如果生成器无错误返回,DefaultFunc 会更简洁,否则在资源创建时设置钩子可以允许你捕获错误。
如何使用 DefaultFunc 的示例可参见 ID 字段 部分。
下面是使用钩子创建 sonyflake 的自定义生成器示例。
// BaseMixin to be shared will all different schemas.
type BaseMixin struct {
mixin.Schema
}
// Fields of the Mixin.
func (BaseMixin) Fields() []ent.Field {
return []ent.Field{
field.Uint64("id"),
}
}
// Hooks of the Mixin.
func (BaseMixin) Hooks() []ent.Hook {
return []ent.Hook{
hook.On(IDHook(), ent.OpCreate),
}
}
func IDHook() ent.Hook {
sf := sonyflake.NewSonyflake(sonyflake.Settings{})
type IDSetter interface {
SetID(uint64)
}
return func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
is, ok := m.(IDSetter)
if !ok {
return nil, fmt.Errorf("unexpected mutation %T", m)
}
id, err := sf.NextID()
if err != nil {
return nil, err
}
is.SetID(id)
return next.Mutate(ctx, m)
})
}
}
// User holds the schema definition for the User entity.
type User struct {
ent.Schema
}
// Mixin of the User.
func (User) Mixin() []ent.Mixin {
return []ent.Mixin{
// Embed the BaseMixin in the user schema.
BaseMixin{},
}
}
如何使用 XID 自定义全球唯一 ID?
xid 是使用 Mongo Object ID 算法无需配置即可生成生成 12 比特、20字符 ID 的全球唯一 ID 生成器。Ent 要求 xid 包使用 database/sql 的 sql.Scanner 和 driver.Valuer 接口进行序列化。
使用 Go 语言类型(GoType) 模式配置在任意字符串字段中存储 XID:
// Fields of type T.
func (T) Fields() []ent.Field {
return []ent.Field{
field.String("id").
GoType(xid.ID{}).
DefaultFunc(xid.New),
}
}
或作为 混合(Mixin) 在多个模式中重复使用:
package schema
import (
"entgo.io/ent"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
"github.com/rs/xid"
)
// BaseMixin to be shared will all different schemas.
type BaseMixin struct {
mixin.Schema
}
// Fields of the User.
func (BaseMixin) Fields() []ent.Field {
return []ent.Field{
field.String("id").
GoType(xid.ID{}).
DefaultFunc(xid.New),
}
}
// User holds the schema definition for the User entity.
type User struct {
ent.Schema
}
// Mixin of the User.
func (User) Mixin() []ent.Mixin {
return []ent.Mixin{
// Embed the BaseMixin in the user schema.
BaseMixin{},
}
}
若使用 gqlgen 的扩展标识符(XID),请遵循 问题跟踪器 中提到的配置。
如何在 MySQL 中定义空间数据类型字段?
Go 语言类型(GoType) 和 数据库类型(Database Type,原文为 SchemaType) 选项允许用户定义数据库特定字段。
例如定义一个 POINT 字段可以使用如下配置:
// Fields of the Location.
func (Location) Fields() []ent.Field {
return []ent.Field{
field.String("name"),
field.Other("coords", &Point{}).
SchemaType(Point{}.SchemaType()),
}
}
package schema
import (
"database/sql/driver"
"fmt"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"github.com/paulmach/orb"
"github.com/paulmach/orb/encoding/wkb"
)
// A Point consists of (X,Y) or (Lat, Lon) coordinates
// and it is stored in MySQL the POINT spatial data type.
type Point [2]float64
// Scan implements the Scanner interface.
func (p *Point) Scan(value any) error {
bin, ok := value.([]byte)
if !ok {
return fmt.Errorf("invalid binary value for point")
}
var op orb.Point
if err := wkb.Scanner(&op).Scan(bin[4:]); err != nil {
return err
}
p[0], p[1] = op.X(), op.Y()
return nil
}
// Value implements the driver Valuer interface.
func (p Point) Value() (driver.Value, error) {
op := orb.Point{p[0], p[1]}
return wkb.Value(op).Value()
}
// FormatParam implements the sql.ParamFormatter interface to tell the SQL
// builder that the placeholder for a Point parameter needs to be formatted.
func (p Point) FormatParam(placeholder string, info *sql.StmtInfo) string {
if info.Dialect == dialect.MySQL {
return "ST_GeomFromWKB(" + placeholder + ")"
}
return placeholder
}
// SchemaType defines the schema-type of the Point object.
func (Point) SchemaType() map[string]string {
return map[string]string{
dialect.MySQL: "POINT",
}
}
完整示例参见 示例代码仓库。
如何扩展已生成的模型?
Ent 支持使用自定义模板扩展已生成的类型(全局类型和模型)。
例如向已生成的模型中添加额外的结构体字段或方法,可以像 此示例 一样覆盖 model/fields/additional 模板。
如果自定义字段或方法要求额外的导入,同样可以使用自定义模板添加这些导入:
{{- define "import/additional/field_types" -}}
"github.com/path/to/your/custom/type"
{{- end -}}
{{- define "import/additional/client_dependencies" -}}
"github.com/path/to/your/custom/type"
{{- end -}}
如果扩展已生成的构建器?
如何在 BLOB 列中存储 Protobuf 对象?
假设我们有一个定义如下的 Protobuf 消息:
syntax = "proto3";
package pb;
option go_package = "project/pb";
message Hi {
string Greeting = 1;
}
我们向生成的 protobuf 结构体中添加消息接受方法,此方法实现 ValueScanner 接口。
func (x *Hi) Value() (driver.Value, error) {
return proto.Marshal(x)
}
func (x *Hi) Scan(src any) error {
if src == nil {
return nil
}
if b, ok := src.([]byte); ok {
if err := proto.Unmarshal(b, x); err != nil {
return err
}
return nil
}
return fmt.Errorf("unexpected type %T", src)
}
我们向模式中添加一个新的 field.Bytes,设置生成的底层为 GoType 的 protobuf 结构体:
// Fields of the Message.
func (Message) Fields() []ent.Field {
return []ent.Field{
field.Bytes("hi").
GoType(&pb.Hi{}),
}
}
测试是否有效:
package main
import (
"context"
"testing"
"project/ent/enttest"
"project/pb"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
func TestMain(t *testing.T) {
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
msg := client.Message.Create().
SetHi(&pb.Hi{
Greeting: "hello",
}).
SaveX(context.TODO())
ret := client.Message.GetX(context.TODO(), msg.ID)
require.Equal(t, "hello", ret.Hi.Greeting)
}
如何为表添加 CHECK 约束?
entsql.Annotation 选项允许将自定义 CHECK 约束添加到 CREATE TABLE 语句中。
像如下示例将 CHECK 约束添加到模式中:
func (User) Annotations() []schema.Annotation {
return []schema.Annotation{
&entsql.Annotation{
// The `Check` option allows adding an
// unnamed CHECK constraint to table DDL.
Check: "website <> 'entgo.io'",
// The `Checks` option allows adding multiple CHECK constraints
// to table creation. The keys are used as the constraint names.
Checks: map[string]string{
"valid_nickname": "nickname <> firstname",
"valid_firstname": "length(first_name) > 1",
},
},
}
}
如何定义自定义精度的数字字段?
使用 Go 语言类型(GoType) 和 数据库类型(Database Type,原文为 SchemaType) 来定义自定义精度的数字字段。 例如使用 big.Int 定义一个字段:
func (T) Fields() []ent.Field {
return []ent.Field{
field.Int("precise").
GoType(new(BigInt)).
SchemaType(map[string]string{
dialect.SQLite: "numeric(78, 0)",
dialect.Postgres: "numeric(78, 0)",
}),
}
}
type BigInt struct {
big.Int
}
func (b *BigInt) Scan(src any) error {
var i sql.NullString
if err := i.Scan(src); err != nil {
return err
}
if !i.Valid {
return nil
}
if _, ok := b.Int.SetString(i.String, 10); ok {
return nil
}
return fmt.Errorf("could not scan type %T with value %v into BigInt", src, src)
}
func (b *BigInt) Value() (driver.Value, error) {
return b.String(), nil
}
如何配置两个及以上 DB 实现读写分离?
可以用自己的驱动程序封装 dialect.Driver 并实现相应逻辑。例如:
你可以扩展它,添加对多个读副本的支持,并加入一些负载均衡的魔法。
func main() {
// ...
wd, err := sql.Open(dialect.MySQL, "root:pass@tcp(<addr>)/<database>?parseTime=True")
if err != nil {
log.Fatal(err)
}
rd, err := sql.Open(dialect.MySQL, "readonly:pass@tcp(<addr>)/<database>?parseTime=True")
if err != nil {
log.Fatal(err)
}
client := ent.NewClient(ent.Driver(&multiDriver{w: wd, r: rd}))
defer client.Close()
// Use the client here.
}
type multiDriver struct {
r, w dialect.Driver
}
var _ dialect.Driver = (*multiDriver)(nil)
func (d *multiDriver) Query(ctx context.Context, query string, args, v any) error {
e := d.r
// Mutation statements that use the RETURNING clause.
if ent.QueryFromContext(ctx) == nil {
e = d.w
}
return e.Query(ctx, query, args, v)
}
func (d *multiDriver) Exec(ctx context.Context, query string, args, v any) error {
return d.w.Exec(ctx, query, args, v)
}
func (d *multiDriver) Tx(ctx context.Context) (dialect.Tx, error) {
return d.w.Tx(ctx)
}
func (d *multiDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) {
return d.w.(interface {
BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error)
}).BeginTx(ctx, opts)
}
func (d *multiDriver) Close() error {
rerr := d.r.Close()
werr := d.w.Close()
if rerr != nil {
return rerr
}
if werr != nil {
return werr
}
return nil
}
func (d *multiDriver) Dialect() string {
return d.r.Dialect()
}
如何配置 json.Marshal 以内联顶级对象中的 edges 键?
按以下两步编码没有 edges 属性的实体:
- 忽略 Ent 生成的默认
edges标签。 - 用自定义 MarshalJSON 方法扩展生成的模型。
这两步可以使用 代码生成扩展 自动完成,完整示例参见 examples/jsonencode 目录。
//go:build ignore
// +build ignore
package main
import (
"log"
"entgo.io/ent/entc"
"entgo.io/ent/entc/gen"
"entgo.io/ent/schema/edge"
)
func main() {
opts := []entc.Option{
entc.Extensions{
&EncodeExtension{},
),
}
err := entc.Generate("./schema", &gen.Config{}, opts...)
if err != nil {
log.Fatalf("running ent codegen: %v", err)
}
}
// EncodeExtension is an implementation of entc.Extension that adds a MarshalJSON
// method to each generated type <T> and inlines the Edges field to the top level JSON.
type EncodeExtension struct {
entc.DefaultExtension
}
// Templates of the extension.
func (e *EncodeExtension) Templates() []*gen.Template {
return []*gen.Template{
gen.MustParse(gen.NewTemplate("model/additional/jsonencode").
Parse(`
{{ if $.Edges }}
// MarshalJSON implements the json.Marshaler interface.
func ({{ $.Receiver }} *{{ $.Name }}) MarshalJSON() ([]byte, error) {
type Alias {{ $.Name }}
return json.Marshal(&struct {
*Alias
{{ $.Name }}Edges
}{
Alias: (*Alias)({{ $.Receiver }}),
{{ $.Name }}Edges: {{ $.Receiver }}.Edges,
})
}
{{ end }}
`)),
}
}
// Hooks of the extension.
func (e *EncodeExtension) Hooks() []gen.Hook {
return []gen.Hook{
func(next gen.Generator) gen.Generator {
return gen.GenerateFunc(func(g *gen.Graph) error {
tag := edge.Annotation{StructTag: `json:"-"`}
for _, n := range g.Nodes {
n.Annotations.Set(tag.Name(), tag)
}
return next.Generate(g)
})
},
}
}