diff --git a/driver/docker-container/driver.go b/driver/docker-container/driver.go index 90eadf4b..45f77b31 100644 --- a/driver/docker-container/driver.go +++ b/driver/docker-container/driver.go @@ -56,6 +56,7 @@ type Driver struct { restartPolicy container.RestartPolicy env []string defaultLoad bool + gpus []container.DeviceRequest } func (d *Driver) IsMobyDriver() bool { @@ -158,6 +159,9 @@ func (d *Driver) create(ctx context.Context, l progress.SubLogger) error { if d.cpusetMems != "" { hc.Resources.CpusetMems = d.cpusetMems } + if len(d.gpus) > 0 && d.hasGPUCapability(ctx, cfg.Image, d.gpus) { + hc.Resources.DeviceRequests = d.gpus + } if info, err := d.DockerAPI.Info(ctx); err == nil { if info.CgroupDriver == "cgroupfs" { // Place all buildkit containers inside this cgroup by default so limits can be attached @@ -429,6 +433,31 @@ func (d *Driver) HostGatewayIP(ctx context.Context) (net.IP, error) { return nil, errors.New("host-gateway is not supported by the docker-container driver") } +// hasGPUCapability checks if docker daemon has GPU capability. We need to run +// a dummy container with GPU device to check if the daemon has this capability +// because there is no API to check it yet. +func (d *Driver) hasGPUCapability(ctx context.Context, image string, gpus []container.DeviceRequest) bool { + cfg := &container.Config{ + Image: image, + Entrypoint: []string{"/bin/true"}, + } + hc := &container.HostConfig{ + NetworkMode: container.NetworkMode(container.IPCModeNone), + AutoRemove: true, + Resources: container.Resources{ + DeviceRequests: gpus, + }, + } + resp, err := d.DockerAPI.ContainerCreate(ctx, cfg, hc, &network.NetworkingConfig{}, nil, "") + if err != nil { + return false + } + if err := d.DockerAPI.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { + return false + } + return true +} + func demuxConn(c net.Conn) net.Conn { pr, pw := io.Pipe() // TODO: rewrite parser with Reader() to avoid goroutine switch diff --git a/driver/docker-container/factory.go b/driver/docker-container/factory.go index 18bb2a4a..dba791b6 100644 --- a/driver/docker-container/factory.go +++ b/driver/docker-container/factory.go @@ -51,6 +51,12 @@ func (f *factory) New(ctx context.Context, cfg driver.InitConfig) (driver.Driver InitConfig: cfg, restartPolicy: rp, } + var gpus dockeropts.GpuOpts + if err := gpus.Set("all"); err == nil { + if v := gpus.Value(); len(v) > 0 { + d.gpus = v + } + } for k, v := range cfg.DriverOpts { switch { case k == "network":