Compare commits

...

12 Commits

Author SHA1 Message Date
Christopher Allen Lane
cab039a9d8 docs: move ADRs to project root, remove boilerplate README
Move `doc/adr/` to `adr/` for discoverability. Remove the generic
ADR README — `ls adr/` serves the same purpose.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 07:32:40 -05:00
Christopher Allen Lane
97e80beceb fix: match .git as complete path component, not suffix
Searching for `.git/` in file paths incorrectly matched directory names
ending with `.git` (e.g., `personal.git/cheat/hello`), causing sheets
under such paths to be silently skipped. Fix by requiring the path
separator on both sides (`/.git/`), so `.git` is only matched as a
complete path component.

Rewrites test suite with comprehensive coverage for all six documented
edge cases, including the #711 scenario and combination cases (e.g.,
a real .git directory inside a .git-suffixed parent).

Closes #711

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 07:32:35 -05:00
Christopher Allen Lane
1969423b5c fix: respect $VISUAL and $EDITOR env vars at runtime
Previously, env vars were only consulted during config generation
and baked into conf.yml. At runtime, the config file value was
always used, making it impossible to override the editor via
environment variables.

Now the precedence is: $VISUAL > $EDITOR > conf.yml > auto-detect.

Closes #589

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 06:54:42 -05:00
Christopher Allen Lane
4497ce1b84 ci: remove dead Homebrew formula bump workflow
This workflow has been failing for years due to an expired/missing
COMMITTER_TOKEN. Homebrew maintains their own automated version
bump pipeline, making this redundant.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 06:43:01 -05:00
Christopher Allen Lane
5eee02bc40 build: produce static binaries with CGO_ENABLED=0
Eliminates glibc version mismatch errors when running release
binaries on systems with older glibc versions.

Closes #744

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 06:39:20 -05:00
Christopher Allen Lane
2d50c6a6eb chore: bump version to 4.5.1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 06:26:17 -05:00
Christopher Allen Lane
6f919fd675 fix: comment out community cheatpath in --init output (#773)
cheat --init now comments out the community cheatpath by default and
includes a git clone instruction with the resolved path. This prevents
warnings about missing directories when users save the --init output
as their config without also cloning community cheatsheets.

Closes #773

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 06:21:48 -05:00
Christopher Allen Lane
fd1465ee38 fix: avoid stdin buffering bug in installer prompts
Prompt() created a new bufio.NewReader(os.Stdin) on each call, which
buffered all piped input on the first call and left nothing for
subsequent prompts. This made cheat un-scriptable (e.g., piping answers
via printf). Fix by reading one byte at a time from os.Stdin directly.

Also adds an end-to-end integration test for the first-run experience
(regression test for #721, #771, #730) and bumps the Dockerfile to
Go 1.26.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 21:51:30 -05:00
Christopher Allen Lane
00ec2c130d fix: resolve first-run experience errors (#721, #771, #730)
- cmdInit (--init) now substitutes EDITOR_PATH, PAGER_PATH, and
  WORK_PATH instead of leaving them as literal strings
- Installer now substitutes WORK_PATH and always creates personal
  and work directories regardless of community cheatsheet choice
- When community cheatsheets are declined, the community cheatpath
  is commented out in the generated config
- config.New() skips nonexistent cheatpaths with a warning instead
  of hard-erroring on EvalSymlinks failure

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 21:32:20 -05:00
Christopher Allen Lane
8eafa5adfe fix: cross-platform CI test fixes and parse bug fix
- Add .gitattributes to force LF in mock files (Windows autocrlf)
- Fix parse.go: detect line endings from content instead of runtime.GOOS
- Add fail-fast: false to CI matrix; trigger on all branch pushes
- Skip chmod-based tests on Windows (permissions work differently)
- Use filepath.Join for expected paths in Windows path tests
- Use platform-appropriate invalid paths in error tests
- Add Windows absolute path test case for ValidateSheetName
- Skip Unix-specific integration tests on Windows

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 21:31:26 -05:00
Christopher Allen Lane
b604027205 fix: make tests pass on macOS CI runners
- Resolve symlinks in temp dir paths (macOS /var -> /private/var)
- Pre-create non-empty community dir to ensure clone fails reliably
  regardless of network access on CI runners

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 21:05:16 -05:00
Christopher Allen Lane
2a19755804 chore: modernize CI and update Go toolchain
- Bump Go from 1.19 to 1.26 and update all dependencies
- Rewrite CI workflow with matrix strategy (Linux, macOS, Windows)
- Update GitHub Actions to current versions (checkout@v4, setup-go@v5)
- Update CodeQL actions from v1 to v3
- Fix cross-platform bug in mock/path.go (path.Join -> filepath.Join)
- Clean up dependabot config (weekly schedule, remove stale ignore)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 20:58:51 -05:00
688 changed files with 49979 additions and 32269 deletions

3
.gitattributes vendored Normal file
View File

@@ -0,0 +1,3 @@
# Force LF line endings for mock/test data files to ensure consistent
# behavior across platforms (Windows git autocrlf converts to CRLF otherwise)
mocks/** text eol=lf

View File

@@ -3,9 +3,5 @@ updates:
- package-ecosystem: gomod - package-ecosystem: gomod
directory: "/" directory: "/"
schedule: schedule:
interval: daily interval: weekly
open-pull-requests-limit: 10 open-pull-requests-limit: 10
ignore:
- dependency-name: github.com/alecthomas/chroma
versions:
- 0.9.1

View File

@@ -1,46 +1,38 @@
--- ---
name: Go name: CI
on: on:
push: push:
branches: [master]
pull_request:
branches: [master]
jobs: jobs:
# TODO: is it possible to DRY out these jobs? Aside from `runs-on`, they are lint:
# identical. runs-on: ubuntu-latest
# See: https://github.com/actions/runner/issues/1182
build-linux:
runs-on: [ubuntu-latest]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: Set up Go - uses: actions/setup-go@v5
uses: actions/setup-go@v2
with: with:
go-version: 1.19 go-version: stable
- name: Set up Revive (linter) - name: Install revive
run: go get -u github.com/boyter/scc github.com/mgechev/revive run: go install github.com/mgechev/revive@latest
env: - name: Lint
GO111MODULE: "off" run: revive -exclude vendor/... ./...
- name: Build - name: Vet
run: make build run: go vet ./...
- name: Test - name: Check formatting
run: make test run: test -z "$(gofmt -l . | grep -v vendor/)"
build-osx: test:
runs-on: [macos-latest] strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- name: Set up Go - uses: actions/setup-go@v5
uses: actions/setup-go@v2
with: with:
go-version: 1.19 go-version: stable
- name: Set up Revive (linter)
run: go get -u github.com/boyter/scc github.com/mgechev/revive
env:
GO111MODULE: "off"
- name: Build - name: Build
run: make build run: go build -mod vendor ./cmd/cheat
- name: Test - name: Test
run: make test run: go test ./...

View File

@@ -19,12 +19,12 @@ jobs:
language: [go] language: [go]
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v2 uses: actions/checkout@v4
- name: Initialize CodeQL - name: Initialize CodeQL
uses: github/codeql-action/init@v1 uses: github/codeql-action/init@v3
with: with:
languages: ${{ matrix.language }} languages: ${{ matrix.language }}
- name: Autobuild - name: Autobuild
uses: github/codeql-action/autobuild@v1 uses: github/codeql-action/autobuild@v3
- name: Perform CodeQL Analysis - name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v1 uses: github/codeql-action/analyze@v3

View File

@@ -1,19 +0,0 @@
---
name: homebrew
on:
push:
tags: '*'
jobs:
homebrew:
name: Bump Homebrew formula
runs-on: ubuntu-latest
steps:
- uses: mislav/bump-homebrew-formula-action@v1
with:
# A PR will be sent to github.com/Homebrew/homebrew-core to update
# this formula:
formula-name: cheat
env:
COMMITTER_TOKEN: ${{ secrets.COMMITTER_TOKEN }}

View File

@@ -1,7 +1,7 @@
# NB: this image isn't used anywhere in the build pipeline. It exists to # NB: this image isn't used anywhere in the build pipeline. It exists to
# conveniently facilitate ad-hoc experimentation in a sandboxed environment # conveniently facilitate ad-hoc experimentation in a sandboxed environment
# during development. # during development.
FROM golang:1.15-alpine FROM golang:1.26-alpine
RUN apk add git less make RUN apk add git less make

View File

@@ -9,13 +9,13 @@ On Unix-like systems, you may simply paste the following snippet into your termi
```sh ```sh
cd /tmp \ cd /tmp \
&& wget https://github.com/cheat/cheat/releases/download/4.5.0/cheat-linux-amd64.gz \ && wget https://github.com/cheat/cheat/releases/download/4.5.1/cheat-linux-amd64.gz \
&& gunzip cheat-linux-amd64.gz \ && gunzip cheat-linux-amd64.gz \
&& chmod +x cheat-linux-amd64 \ && chmod +x cheat-linux-amd64 \
&& sudo mv cheat-linux-amd64 /usr/local/bin/cheat && sudo mv cheat-linux-amd64 /usr/local/bin/cheat
``` ```
You may need to need to change the version number (`4.5.0`) and the archive You may need to need to change the version number (`4.5.1`) and the archive
(`cheat-linux-amd64.gz`) depending on your platform. (`cheat-linux-amd64.gz`) depending on your platform.
See the [releases page][releases] for a list of supported platforms. See the [releases page][releases] for a list of supported platforms.

View File

@@ -27,6 +27,7 @@ ZIP := zip -m
docker_image := cheat-devel:latest docker_image := cheat-devel:latest
# build flags # build flags
export CGO_ENABLED := 0
BUILD_FLAGS := -ldflags="-s -w" -mod vendor -trimpath BUILD_FLAGS := -ldflags="-s -w" -mod vendor -trimpath
GOBIN := GOBIN :=
TMPDIR := /tmp TMPDIR := /tmp

View File

@@ -44,13 +44,37 @@ func cmdInit() {
confpath := confpaths[0] confpath := confpaths[0]
confdir := filepath.Dir(confpath) confdir := filepath.Dir(confpath)
// create paths for community and personal cheatsheets // create paths for community, personal, and work cheatsheets
community := filepath.Join(confdir, "cheatsheets", "community") community := filepath.Join(confdir, "cheatsheets", "community")
personal := filepath.Join(confdir, "cheatsheets", "personal") personal := filepath.Join(confdir, "cheatsheets", "personal")
work := filepath.Join(confdir, "cheatsheets", "work")
// template the above paths into the default configs // template the above paths into the default configs
configs = strings.Replace(configs, "COMMUNITY_PATH", community, -1) configs = strings.Replace(configs, "COMMUNITY_PATH", community, -1)
configs = strings.Replace(configs, "PERSONAL_PATH", personal, -1) configs = strings.Replace(configs, "PERSONAL_PATH", personal, -1)
configs = strings.Replace(configs, "WORK_PATH", work, -1)
// locate and set a default pager
configs = strings.Replace(configs, "PAGER_PATH", config.Pager(), -1)
// locate and set a default editor
if editor, err := config.Editor(); err == nil {
configs = strings.Replace(configs, "EDITOR_PATH", editor, -1)
}
// comment out the community cheatpath by default, since the directory
// won't exist until the user clones it
configs = strings.Replace(configs,
" - name: community\n"+
" path: "+community+"\n"+
" tags: [ community ]\n"+
" readonly: true",
" #- name: community\n"+
" # path: "+community+"\n"+
" # tags: [ community ]\n"+
" # readonly: true",
-1,
)
// output the templated configs // output the templated configs
fmt.Println(configs) fmt.Println(configs)

View File

@@ -3,7 +3,7 @@ package main
// configs returns the default configuration template // configs returns the default configuration template
func configs() string { func configs() string {
return `--- return `---
# The editor to use with 'cheat -e <sheet>'. Defaults to $EDITOR or $VISUAL. # The editor to use with 'cheat -e <sheet>'. Overridden by $VISUAL or $EDITOR.
editor: EDITOR_PATH editor: EDITOR_PATH
# Should 'cheat' always colorize output? # Should 'cheat' always colorize output?
@@ -56,7 +56,8 @@ cheatpaths:
tags: [ work ] tags: [ work ]
readonly: false readonly: false
# Community cheatsheets are stored here by default: # Community cheatsheets (https://github.com/cheat/cheatsheets):
# To install: git clone https://github.com/cheat/cheatsheets COMMUNITY_PATH
- name: community - name: community
path: COMMUNITY_PATH path: COMMUNITY_PATH
tags: [ community ] tags: [ community ]

View File

@@ -0,0 +1,304 @@
package main
import (
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
)
// TestFirstRunIntegration exercises the end-to-end first-run experience:
// no config exists, the binary creates one, and subsequent runs succeed.
// This is the regression test for issues #721, #771, and #730.
func TestFirstRunIntegration(t *testing.T) {
// Build the cheat binary
binName := "cheat_test"
if runtime.GOOS == "windows" {
binName += ".exe"
}
binPath := filepath.Join(t.TempDir(), binName)
build := exec.Command("go", "build", "-o", binPath, ".")
if output, err := build.CombinedOutput(); err != nil {
t.Fatalf("failed to build cheat: %v\nOutput: %s", err, output)
}
t.Run("init comments out community", func(t *testing.T) {
testHome := t.TempDir()
env := firstRunEnv(testHome)
cmd := exec.Command(binPath, "--init")
cmd.Env = env
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("--init failed: %v\nOutput: %s", err, output)
}
outStr := string(output)
// No placeholder strings should survive (regression for #721)
assertNoPlaceholders(t, outStr)
// Community cheatpath should be commented out
assertCommunityCommentedOut(t, outStr)
// Personal and work cheatpaths should be active (uncommented)
assertCheatpathActive(t, outStr, "personal")
assertCheatpathActive(t, outStr, "work")
// Should include clone instructions
if !strings.Contains(outStr, "git clone") {
t.Error("expected git clone instructions in --init output")
}
// Save the config and verify it loads without errors.
// --init only outputs config, it doesn't create directories,
// so we need to create the cheatpath dirs the config references.
confpath := filepath.Join(testHome, "conf.yml")
if err := os.WriteFile(confpath, output, 0644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
// Determine the confdir that --init used (same logic as cmd_init.go)
initConfpaths := firstRunConfpaths(testHome)
initConfdir := filepath.Dir(initConfpaths[0])
for _, name := range []string{"personal", "work"} {
dir := filepath.Join(initConfdir, "cheatsheets", name)
if err := os.MkdirAll(dir, 0755); err != nil {
t.Fatalf("failed to create %s dir: %v", name, err)
}
}
cmd2 := exec.Command(binPath, "--directories")
cmd2.Env = append(append([]string{}, env...), "CHEAT_CONFIG_PATH="+confpath)
output2, err := cmd2.CombinedOutput()
if err != nil {
t.Fatalf("config from --init failed to load: %v\nOutput: %s", err, output2)
}
})
t.Run("decline config creation", func(t *testing.T) {
testHome := t.TempDir()
env := firstRunEnv(testHome)
cmd := exec.Command(binPath)
cmd.Env = env
cmd.Stdin = strings.NewReader("n\n")
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("cheat exited with error: %v\nOutput: %s", err, output)
}
// Verify no config was created
if firstRunConfigExists(testHome) {
t.Error("config file was created despite user declining")
}
})
t.Run("accept config decline community", func(t *testing.T) {
testHome := t.TempDir()
env := firstRunEnv(testHome)
// First run: yes to create config, no to community cheatsheets
cmd := exec.Command(binPath)
cmd.Env = env
cmd.Stdin = strings.NewReader("y\nn\n")
output, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("first run failed: %v\nOutput: %s", err, output)
}
outStr := string(output)
// Parse the config path from output
confpath := parseCreatedConfPath(t, outStr)
if confpath == "" {
t.Fatalf("could not find config path in output:\n%s", outStr)
}
// Verify config file exists
if _, err := os.Stat(confpath); os.IsNotExist(err) {
t.Fatalf("config file not found at %s", confpath)
}
// Verify config file contents
content, err := os.ReadFile(confpath)
if err != nil {
t.Fatalf("failed to read config: %v", err)
}
contentStr := string(content)
// No placeholder strings should survive (regression for #721)
assertNoPlaceholders(t, contentStr)
// Community cheatpath should be commented out
assertCommunityCommentedOut(t, contentStr)
// Personal and work cheatpaths should be active (uncommented)
assertCheatpathActive(t, contentStr, "personal")
assertCheatpathActive(t, contentStr, "work")
// Verify personal and work directories were created
confdir := filepath.Dir(confpath)
for _, name := range []string{"personal", "work"} {
dir := filepath.Join(confdir, "cheatsheets", name)
if _, err := os.Stat(dir); os.IsNotExist(err) {
t.Errorf("expected %s directory at %s", name, dir)
}
}
// Community directory should NOT exist
communityDir := filepath.Join(confdir, "cheatsheets", "community")
if _, err := os.Stat(communityDir); err == nil {
t.Error("community directory should not exist when declined")
}
// --- Second run: verify the config loads successfully ---
// This is the core regression test for #721/#771/#730:
// previously, the second run would fail because config.New()
// hard-errored on the missing community cheatpath directory.
// Use --directories (not --list, which exits 2 when no sheets exist).
cmd2 := exec.Command(binPath, "--directories")
cmd2.Env = append(append([]string{}, env...), "CHEAT_CONFIG_PATH="+confpath)
output2, err := cmd2.CombinedOutput()
if err != nil {
t.Fatalf(
"second run failed (regression for #721/#771/#730): %v\nOutput: %s",
err, output2,
)
}
// Verify the output lists the expected cheatpaths
outStr2 := string(output2)
if !strings.Contains(outStr2, "personal") {
t.Errorf("expected 'personal' cheatpath in --directories output:\n%s", outStr2)
}
if !strings.Contains(outStr2, "work") {
t.Errorf("expected 'work' cheatpath in --directories output:\n%s", outStr2)
}
})
}
// firstRunEnv returns a minimal environment for a clean first-run test.
func firstRunEnv(home string) []string {
env := []string{
"PATH=" + os.Getenv("PATH"),
}
switch runtime.GOOS {
case "windows":
env = append(env,
"APPDATA="+filepath.Join(home, "AppData", "Roaming"),
"USERPROFILE="+home,
"SystemRoot="+os.Getenv("SystemRoot"),
)
default:
env = append(env,
"HOME="+home,
"EDITOR=vi",
)
}
return env
}
// parseCreatedConfPath extracts the config file path from the installer's
// "Created config file: <path>" output. The message may appear mid-line
// (after prompt text), so we search for the substring anywhere in the output.
func parseCreatedConfPath(t *testing.T, output string) string {
t.Helper()
const marker = "Created config file: "
idx := strings.Index(output, marker)
if idx < 0 {
return ""
}
rest := output[idx+len(marker):]
// the path ends at the next newline
if nl := strings.IndexByte(rest, '\n'); nl >= 0 {
rest = rest[:nl]
}
return strings.TrimSpace(rest)
}
// firstRunConfpaths returns the config file paths that cheat would check
// for the given home directory, matching the logic in config.Paths().
func firstRunConfpaths(home string) []string {
switch runtime.GOOS {
case "windows":
return []string{
filepath.Join(home, "AppData", "Roaming", "cheat", "conf.yml"),
}
default:
return []string{
filepath.Join(home, ".config", "cheat", "conf.yml"),
}
}
}
// assertNoPlaceholders verifies that no template placeholder strings survived
// in the config output. This is the regression check for #721 (literal
// PAGER_PATH appearing in the config).
func assertNoPlaceholders(t *testing.T, content string) {
t.Helper()
placeholders := []string{
"PAGER_PATH",
"COMMUNITY_PATH",
"PERSONAL_PATH",
"WORK_PATH",
}
for _, p := range placeholders {
if strings.Contains(content, p) {
t.Errorf("placeholder %q was not replaced in config", p)
}
}
// EDITOR_PATH is special: it survives if no editor is found.
// In our test env EDITOR=vi is set, so it should be replaced.
if strings.Contains(content, "editor: EDITOR_PATH") {
t.Error("placeholder EDITOR_PATH was not replaced in config")
}
}
// assertCommunityCommentedOut verifies that the community cheatpath entry
// is commented out (not active) in the config.
func assertCommunityCommentedOut(t *testing.T, content string) {
t.Helper()
for _, line := range strings.Split(content, "\n") {
trimmed := strings.TrimSpace(line)
if trimmed == "- name: community" {
t.Error("community cheatpath should be commented out")
return
}
}
if !strings.Contains(content, "#- name: community") {
t.Error("expected commented-out community cheatpath")
}
}
// assertCheatpathActive verifies that a named cheatpath is present and
// uncommented in the config.
func assertCheatpathActive(t *testing.T, content string, name string) {
t.Helper()
marker := "- name: " + name
for _, line := range strings.Split(content, "\n") {
trimmed := strings.TrimSpace(line)
if trimmed == marker {
return
}
}
t.Errorf("expected active (uncommented) cheatpath %q", name)
}
// firstRunConfigExists checks whether a cheat config file exists under the
// given home directory at any of the standard locations.
func firstRunConfigExists(home string) bool {
candidates := []string{
filepath.Join(home, ".config", "cheat", "conf.yml"),
filepath.Join(home, ".cheat", "conf.yml"),
filepath.Join(home, "AppData", "Roaming", "cheat", "conf.yml"),
}
for _, p := range candidates {
if _, err := os.Stat(p); err == nil {
return true
}
}
return false
}

View File

@@ -15,7 +15,7 @@ import (
"github.com/cheat/cheat/internal/installer" "github.com/cheat/cheat/internal/installer"
) )
const version = "4.5.0" const version = "4.5.1"
func main() { func main() {

View File

@@ -5,6 +5,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"testing" "testing"
) )
@@ -12,6 +13,10 @@ import (
// TestPathTraversalIntegration tests that the cheat binary properly blocks // TestPathTraversalIntegration tests that the cheat binary properly blocks
// path traversal attempts when invoked as a subprocess. // path traversal attempts when invoked as a subprocess.
func TestPathTraversalIntegration(t *testing.T) { func TestPathTraversalIntegration(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("integration test uses Unix-specific env and tools")
}
// Build the cheat binary // Build the cheat binary
binPath := filepath.Join(t.TempDir(), "cheat_test") binPath := filepath.Join(t.TempDir(), "cheat_test")
if output, err := exec.Command("go", "build", "-o", binPath, ".").CombinedOutput(); err != nil { if output, err := exec.Command("go", "build", "-o", binPath, ".").CombinedOutput(); err != nil {
@@ -146,6 +151,10 @@ cheatpaths:
// TestPathTraversalRealWorld tests with more realistic scenarios // TestPathTraversalRealWorld tests with more realistic scenarios
func TestPathTraversalRealWorld(t *testing.T) { func TestPathTraversalRealWorld(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("integration test uses Unix-specific env and tools")
}
// This test ensures our protection works with actual file operations // This test ensures our protection works with actual file operations
// Build cheat // Build cheat

View File

@@ -1,38 +0,0 @@
# Architecture Decision Records
This directory contains Architecture Decision Records (ADRs) for the cheat project.
## What is an ADR?
An Architecture Decision Record captures an important architectural decision made along with its context and consequences. ADRs help us:
- Document why decisions were made
- Understand the context and trade-offs
- Review decisions when requirements change
- Onboard new contributors
## ADR Format
Each ADR follows this template:
1. **Title**: ADR-NNN: Brief description
2. **Date**: When the decision was made
3. **Status**: Proposed, Accepted, Deprecated, Superseded
4. **Context**: What prompted this decision?
5. **Decision**: What did we decide to do?
6. **Consequences**: What are the positive, negative, and neutral outcomes?
## Index of ADRs
| ADR | Title | Status | Date |
|-----|-------|--------|------|
| [001](001-path-traversal-protection.md) | Path Traversal Protection for Cheatsheet Names | Accepted | 2025-01-21 |
| [002](002-environment-variable-parsing.md) | No Defensive Checks for Environment Variable Parsing | Accepted | 2025-01-21 |
| [003](003-search-parallelization.md) | No Parallelization for Search Operations | Accepted | 2025-01-22 |
## Creating a New ADR
1. Copy the template from an existing ADR
2. Use the next sequential number
3. Fill in all sections
4. Include the ADR alongside the commit implementing the decision

39
go.mod
View File

@@ -1,38 +1,37 @@
module github.com/cheat/cheat module github.com/cheat/cheat
go 1.19 go 1.26
require ( require (
github.com/alecthomas/chroma/v2 v2.12.0 github.com/alecthomas/chroma/v2 v2.23.1
github.com/davecgh/go-spew v1.1.1 github.com/davecgh/go-spew v1.1.1
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815
github.com/go-git/go-git/v5 v5.11.0 github.com/go-git/go-git/v5 v5.16.5
github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-isatty v0.0.20
github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/go-homedir v1.1.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
dario.cat/mergo v1.0.0 // indirect dario.cat/mergo v1.0.2 // indirect
github.com/Microsoft/go-winio v0.6.1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v0.0.0-20230923063757-afb1ddc0824c // indirect github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/cloudflare/circl v1.3.7 // indirect github.com/cloudflare/circl v1.6.3 // indirect
github.com/cyphar/filepath-securejoin v0.2.4 // indirect github.com/cyphar/filepath-securejoin v0.6.1 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/emirpasic/gods v1.18.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
github.com/go-git/go-billy/v5 v5.5.0 // indirect github.com/go-git/go-billy/v5 v5.7.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/kevinburke/ssh_config v1.5.0 // indirect
github.com/pjbgf/sha1cd v0.3.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/sergi/go-diff v1.3.1 // indirect github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/skeema/knownhosts v1.2.1 // indirect github.com/sergi/go-diff v1.4.0 // indirect
github.com/skeema/knownhosts v1.3.2 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect
golang.org/x/crypto v0.17.0 // indirect golang.org/x/crypto v0.48.0 // indirect
golang.org/x/mod v0.14.0 // indirect golang.org/x/net v0.50.0 // indirect
golang.org/x/net v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/tools v0.16.1 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect
) )

152
go.sum
View File

@@ -1,140 +1,118 @@
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/ProtonMail/go-crypto v0.0.0-20230923063757-afb1ddc0824c h1:kMFnB0vCcX7IL/m9Y5LO+KQYv+t1CQOiFe6+SV2J7bE= github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
github.com/ProtonMail/go-crypto v0.0.0-20230923063757-afb1ddc0824c/go.mod h1:EjAoLdwvbIOoOQr3ihjnSoLZRtE8azugULFRteWMNc0= github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
github.com/alecthomas/assert/v2 v2.2.1 h1:XivOgYcduV98QCahG8T5XTezV5bylXe+lBxLG2K2ink= github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
github.com/alecthomas/chroma/v2 v2.12.0 h1:Wh8qLEgMMsN7mgyG8/qIpegky2Hvzr4By6gEF7cmWgw= github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.12.0/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw= github.com/alecthomas/chroma/v2 v2.23.1 h1:nv2AVZdTyClGbVQkIzlDm/rnhk1E9bU9nXwmZ/Vk/iY=
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk= github.com/alecthomas/chroma/v2 v2.23.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8=
github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4=
github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= github.com/cyphar/filepath-securejoin v0.6.1 h1:5CeZ1jPXEiYt3+Z6zqprSAgSWiggmpVyciv8syjIpVE=
github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= github.com/cyphar/filepath-securejoin v0.6.1/go.mod h1:A8hd4EnAeyujCJRrICiOWqjS1AX0a9kM5XL+NwKoYSc=
github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 h1:bWDMxwH3px2JBh6AyO7hdCn/PkvCZXii8TGj7sbtEbQ= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 h1:bWDMxwH3px2JBh6AyO7hdCn/PkvCZXii8TGj7sbtEbQ=
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU= github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI=
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic=
github.com/go-git/go-billy/v5 v5.5.0 h1:yEY4yhzCDuMGSv83oGxiBotRzhwhNr8VZyphhiu+mTU= github.com/go-git/go-billy/v5 v5.7.0 h1:83lBUJhGWhYp0ngzCMSgllhUSuoHP1iEWYjsPl9nwqM=
github.com/go-git/go-billy/v5 v5.5.0/go.mod h1:hmexnoNsr2SJU1Ju67OaNz5ASJY3+sHgFRpCtpDCKow= github.com/go-git/go-billy/v5 v5.7.0/go.mod h1:/1IUejTKH8xipsAcdfcSAlUlo2J7lkYV8GTKxAT/L3E=
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4=
github.com/go-git/go-git/v5 v5.11.0 h1:XIZc1p+8YzypNr34itUfSvYJcv+eYdTnTvOZ2vD3cA4= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII=
github.com/go-git/go-git/v5 v5.11.0/go.mod h1:6GFcX2P3NM7FPBfpePbpLd21XxsgdAt+lKqXmCUiUCY= github.com/go-git/go-git/v5 v5.16.5 h1:mdkuqblwr57kVfXri5TTH+nMFLNUxIj9Z7F5ykFbw5s=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/go-git/go-git/v5 v5.16.5/go.mod h1:QOMLpNf1qxuSY4StA/ArOdfFR2TrKEjJiye2kel2m+M=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A=
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= github.com/kevinburke/ssh_config v1.5.0 h1:3cPZmE54xb5j3G5xQCjSvokqNwU2uW+3ry1+PRLSPpA=
github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/kevinburke/ssh_config v1.5.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/skeema/knownhosts v1.2.1 h1:SHWdIUa82uGZz+F+47k8SY4QhhI291cXCpopT1lK2AQ= github.com/skeema/knownhosts v1.3.2 h1:EDL9mgf4NzwMXCTfaxSD/o/a5fxDw/xL9nkU28JjdBg=
github.com/skeema/knownhosts v1.2.1/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= github.com/skeema/knownhosts v1.3.2/go.mod h1:bEg3iQAuw+jyiw+484wwFJoKSLwcfd7fqRy+N0QTiow=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA=
golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@@ -1,6 +1,7 @@
package cheatpath package cheatpath
import ( import (
"runtime"
"strings" "strings"
"testing" "testing"
) )
@@ -53,9 +54,15 @@ func TestValidateSheetName(t *testing.T) {
errMsg: "'..'", errMsg: "'..'",
}, },
{ {
name: "absolute path", name: "absolute path unix",
input: "/etc/passwd", input: "/etc/passwd",
wantErr: true, wantErr: runtime.GOOS != "windows", // /etc/passwd is not absolute on Windows
errMsg: "absolute",
},
{
name: "absolute path windows",
input: `C:\evil`,
wantErr: runtime.GOOS == "windows", // C:\evil is not absolute on Unix
errMsg: "absolute", errMsg: "absolute",
}, },
{ {

View File

@@ -64,7 +64,8 @@ func New(_ map[string]interface{}, confPath string, resolve bool) (Config, error
} }
// process cheatpaths // process cheatpaths
for i, cheatpath := range conf.Cheatpaths { var validPaths []cp.Cheatpath
for _, cheatpath := range conf.Cheatpaths {
// expand ~ in config paths // expand ~ in config paths
expanded, err := homedir.Expand(cheatpath.Path) expanded, err := homedir.Expand(cheatpath.Path)
@@ -83,6 +84,14 @@ func New(_ map[string]interface{}, confPath string, resolve bool) (Config, error
if resolve { if resolve {
evaled, err := filepath.EvalSymlinks(expanded) evaled, err := filepath.EvalSymlinks(expanded)
if err != nil { if err != nil {
// if the path simply doesn't exist, warn and skip it
if os.IsNotExist(err) {
fmt.Fprintf(os.Stderr,
"WARNING: cheatpath '%s' does not exist, skipping\n",
expanded,
)
continue
}
return Config{}, fmt.Errorf( return Config{}, fmt.Errorf(
"failed to resolve symlink: %s: %v", "failed to resolve symlink: %s: %v",
expanded, expanded,
@@ -93,13 +102,22 @@ func New(_ map[string]interface{}, confPath string, resolve bool) (Config, error
expanded = evaled expanded = evaled
} }
conf.Cheatpaths[i].Path = expanded cheatpath.Path = expanded
validPaths = append(validPaths, cheatpath)
}
conf.Cheatpaths = validPaths
// determine the editor: env vars override the config file value,
// following standard Unix convention (see #589)
if v := os.Getenv("VISUAL"); v != "" {
conf.Editor = v
} else if v := os.Getenv("EDITOR"); v != "" {
conf.Editor = v
} else {
conf.Editor = strings.TrimSpace(conf.Editor)
} }
// trim editor whitespace // if an editor was still not determined, attempt to choose one
conf.Editor = strings.TrimSpace(conf.Editor)
// if an editor was not provided in the configs, attempt to choose one
// that's appropriate for the environment // that's appropriate for the environment
if conf.Editor == "" { if conf.Editor == "" {
if conf.Editor, err = Editor(); err != nil { if conf.Editor, err = Editor(); err != nil {

View File

@@ -3,6 +3,7 @@ package config
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"testing" "testing"
"github.com/cheat/cheat/internal/mock" "github.com/cheat/cheat/internal/mock"
@@ -39,6 +40,12 @@ func TestConfigLocalCheatpath(t *testing.T) {
} }
defer os.RemoveAll(tempDir) defer os.RemoveAll(tempDir)
// Resolve symlinks in temp dir path (macOS /var -> /private/var)
tempDir, err = filepath.EvalSymlinks(tempDir)
if err != nil {
t.Fatalf("failed to resolve temp dir symlinks: %v", err)
}
// Save current working directory // Save current working directory
oldCwd, err := os.Getwd() oldCwd, err := os.Getwd()
if err != nil { if err != nil {
@@ -106,6 +113,12 @@ func TestConfigSymlinkResolution(t *testing.T) {
} }
defer os.RemoveAll(tempDir) defer os.RemoveAll(tempDir)
// Resolve symlinks in temp dir path (macOS /var -> /private/var)
tempDir, err = filepath.EvalSymlinks(tempDir)
if err != nil {
t.Fatalf("failed to resolve temp dir symlinks: %v", err)
}
// Create target directory // Create target directory
targetDir := filepath.Join(tempDir, "target") targetDir := filepath.Join(tempDir, "target")
err = os.Mkdir(targetDir, 0755) err = os.Mkdir(targetDir, 0755)
@@ -176,10 +189,14 @@ cheatpaths:
t.Fatalf("failed to write config: %v", err) t.Fatalf("failed to write config: %v", err)
} }
// Load config with symlink resolution should fail // Load config with symlink resolution should skip the broken cheatpath
_, err = New(map[string]interface{}{}, configFile, true) // (warn to stderr) rather than hard-error
if err == nil { conf, err := New(map[string]interface{}{}, configFile, true)
t.Error("expected error for broken symlink, got nil") if err != nil {
t.Errorf("expected no error for broken symlink (should skip), got: %v", err)
}
if len(conf.Cheatpaths) != 0 {
t.Errorf("expected broken cheatpath to be filtered out, got %d cheatpaths", len(conf.Cheatpaths))
} }
} }
@@ -214,6 +231,10 @@ cheatpaths:
// TestConfigGetCwdError tests error handling when os.Getwd fails // TestConfigGetCwdError tests error handling when os.Getwd fails
func TestConfigGetCwdError(t *testing.T) { func TestConfigGetCwdError(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Windows does not allow removing the current directory")
}
// This is difficult to test without being able to break os.Getwd // This is difficult to test without being able to break os.Getwd
// We'll create a scenario where the current directory is removed // We'll create a scenario where the current directory is removed

View File

@@ -16,6 +16,16 @@ import (
// TestConfig asserts that the configs are loaded correctly // TestConfig asserts that the configs are loaded correctly
func TestConfigSuccessful(t *testing.T) { func TestConfigSuccessful(t *testing.T) {
// clear env vars so they don't override the config file value
oldVisual := os.Getenv("VISUAL")
oldEditor := os.Getenv("EDITOR")
os.Unsetenv("VISUAL")
os.Unsetenv("EDITOR")
defer func() {
os.Setenv("VISUAL", oldVisual)
os.Setenv("EDITOR", oldEditor)
}()
// initialize a config // initialize a config
conf, err := New(map[string]interface{}{}, mock.Path("conf/conf.yml"), false) conf, err := New(map[string]interface{}{}, mock.Path("conf/conf.yml"), false)
if err != nil { if err != nil {
@@ -75,37 +85,78 @@ func TestConfigFailure(t *testing.T) {
} }
} }
// TestEmptyEditor asserts that envvars are respected if an editor is not // TestEditorEnvOverride asserts that $VISUAL and $EDITOR override the
// specified in the configs // config file value at runtime (regression test for #589)
func TestEmptyEditor(t *testing.T) { func TestEditorEnvOverride(t *testing.T) {
// save and clear the environment variables
oldVisual := os.Getenv("VISUAL")
oldEditor := os.Getenv("EDITOR")
defer func() {
os.Setenv("VISUAL", oldVisual)
os.Setenv("EDITOR", oldEditor)
}()
// clear the environment variables // with no env vars, the config file value should be used
os.Setenv("VISUAL", "") os.Unsetenv("VISUAL")
os.Setenv("EDITOR", "") os.Unsetenv("EDITOR")
conf, err := New(map[string]interface{}{}, mock.Path("conf/conf.yml"), false)
if err != nil {
t.Fatalf("failed to init configs: %v", err)
}
if conf.Editor != "vim" {
t.Errorf("expected config file editor: want: vim, got: %s", conf.Editor)
}
// initialize a config // $EDITOR should override the config file value
os.Setenv("EDITOR", "nano")
conf, err = New(map[string]interface{}{}, mock.Path("conf/conf.yml"), false)
if err != nil {
t.Fatalf("failed to init configs: %v", err)
}
if conf.Editor != "nano" {
t.Errorf("$EDITOR should override config: want: nano, got: %s", conf.Editor)
}
// $VISUAL should override both $EDITOR and the config file value
os.Setenv("VISUAL", "emacs")
conf, err = New(map[string]interface{}{}, mock.Path("conf/conf.yml"), false)
if err != nil {
t.Fatalf("failed to init configs: %v", err)
}
if conf.Editor != "emacs" {
t.Errorf("$VISUAL should override all: want: emacs, got: %s", conf.Editor)
}
}
// TestEditorEnvFallback asserts that env vars are used as fallback when
// no editor is specified in the config file
func TestEditorEnvFallback(t *testing.T) {
// save and clear the environment variables
oldVisual := os.Getenv("VISUAL")
oldEditor := os.Getenv("EDITOR")
defer func() {
os.Setenv("VISUAL", oldVisual)
os.Setenv("EDITOR", oldEditor)
}()
// set $EDITOR and assert it's used when config has no editor
os.Unsetenv("VISUAL")
os.Setenv("EDITOR", "foo")
conf, err := New(map[string]interface{}{}, mock.Path("conf/empty.yml"), false) conf, err := New(map[string]interface{}{}, mock.Path("conf/empty.yml"), false)
if err != nil { if err != nil {
t.Errorf("failed to initialize test: %v", err) t.Fatalf("failed to init configs: %v", err)
}
// set editor, and assert that it is respected
os.Setenv("EDITOR", "foo")
conf, err = New(map[string]interface{}{}, mock.Path("conf/empty.yml"), false)
if err != nil {
t.Errorf("failed to init configs: %v", err)
} }
if conf.Editor != "foo" { if conf.Editor != "foo" {
t.Errorf("failed to respect editor: want: foo, got: %s", conf.Editor) t.Errorf("failed to respect $EDITOR: want: foo, got: %s", conf.Editor)
} }
// set visual, and assert that it overrides editor // set $VISUAL and assert it takes precedence over $EDITOR
os.Setenv("VISUAL", "bar") os.Setenv("VISUAL", "bar")
conf, err = New(map[string]interface{}{}, mock.Path("conf/empty.yml"), false) conf, err = New(map[string]interface{}{}, mock.Path("conf/empty.yml"), false)
if err != nil { if err != nil {
t.Errorf("failed to init configs: %v", err) t.Fatalf("failed to init configs: %v", err)
} }
if conf.Editor != "bar" { if conf.Editor != "bar" {
t.Errorf("failed to respect editor: want: bar, got: %s", conf.Editor) t.Errorf("failed to respect $VISUAL: want: bar, got: %s", conf.Editor)
} }
} }

View File

@@ -3,6 +3,7 @@ package config
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"testing" "testing"
) )
@@ -74,12 +75,18 @@ func TestInitCreateDirectory(t *testing.T) {
// TestInitWriteError tests error handling when file write fails // TestInitWriteError tests error handling when file write fails
func TestInitWriteError(t *testing.T) { func TestInitWriteError(t *testing.T) {
// Skip this test if running as root (can write anywhere) // Skip this test if running as root (can write anywhere)
if os.Getuid() == 0 { if runtime.GOOS != "windows" && os.Getuid() == 0 {
t.Skip("Cannot test write errors as root") t.Skip("Cannot test write errors as root")
} }
// Use a platform-appropriate invalid path
invalidPath := "/dev/null/impossible/path/conf.yml"
if runtime.GOOS == "windows" {
invalidPath = `NUL\impossible\path\conf.yml`
}
// Try to write to a read-only directory // Try to write to a read-only directory
err := Init("/dev/null/impossible/path/conf.yml", "test") err := Init(invalidPath, "test")
if err == nil { if err == nil {
t.Error("expected error when writing to invalid path, got nil") t.Error("expected error when writing to invalid path, got nil")
} }

View File

@@ -7,6 +7,16 @@ import (
) )
func TestNewTrimsWhitespace(t *testing.T) { func TestNewTrimsWhitespace(t *testing.T) {
// clear env vars so they don't override the config file value
oldVisual := os.Getenv("VISUAL")
oldEditor := os.Getenv("EDITOR")
os.Unsetenv("VISUAL")
os.Unsetenv("EDITOR")
defer func() {
os.Setenv("VISUAL", oldVisual)
os.Setenv("EDITOR", oldEditor)
}()
// Create a temporary config file with whitespace in editor and pager // Create a temporary config file with whitespace in editor and pager
tmpDir := t.TempDir() tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yml") configPath := filepath.Join(tmpDir, "config.yml")

View File

@@ -1,7 +1,9 @@
package config package config
import ( import (
"path/filepath"
"reflect" "reflect"
"runtime"
"testing" "testing"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
@@ -10,6 +12,9 @@ import (
// TestValidatePathsNix asserts that the proper config paths are returned on // TestValidatePathsNix asserts that the proper config paths are returned on
// *nix platforms // *nix platforms
func TestValidatePathsNix(t *testing.T) { func TestValidatePathsNix(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("filepath.Join uses backslashes on Windows")
}
// mock the user's home directory // mock the user's home directory
home := "/home/foo" home := "/home/foo"
@@ -57,6 +62,9 @@ func TestValidatePathsNix(t *testing.T) {
// TestValidatePathsNixNoXDG asserts that the proper config paths are returned // TestValidatePathsNixNoXDG asserts that the proper config paths are returned
// on *nix platforms when `XDG_CONFIG_HOME is not set // on *nix platforms when `XDG_CONFIG_HOME is not set
func TestValidatePathsNixNoXDG(t *testing.T) { func TestValidatePathsNixNoXDG(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("filepath.Join uses backslashes on Windows")
}
// mock the user's home directory // mock the user's home directory
home := "/home/foo" home := "/home/foo"
@@ -106,8 +114,8 @@ func TestValidatePathsWindows(t *testing.T) {
// mock some envvars // mock some envvars
envvars := map[string]string{ envvars := map[string]string{
"APPDATA": "/apps", "APPDATA": filepath.Join("C:", "apps"),
"PROGRAMDATA": "/programs", "PROGRAMDATA": filepath.Join("C:", "programs"),
} }
// get the paths for the platform // get the paths for the platform
@@ -118,8 +126,8 @@ func TestValidatePathsWindows(t *testing.T) {
// specify the expected output // specify the expected output
want := []string{ want := []string{
"/apps/cheat/conf.yml", filepath.Join("C:", "apps", "cheat", "conf.yml"),
"/programs/cheat/conf.yml", filepath.Join("C:", "programs", "cheat", "conf.yml"),
} }
// assert that output matches expectations // assert that output matches expectations

View File

@@ -3,7 +3,6 @@
package installer package installer
import ( import (
"bufio"
"fmt" "fmt"
"os" "os"
"strings" "strings"
@@ -12,20 +11,34 @@ import (
// Prompt prompts the user for a answer // Prompt prompts the user for a answer
func Prompt(prompt string, def bool) (bool, error) { func Prompt(prompt string, def bool) (bool, error) {
// initialize a line reader
reader := bufio.NewReader(os.Stdin)
// display the prompt // display the prompt
fmt.Printf("%s: ", prompt) fmt.Printf("%s: ", prompt)
// read the answer // read one byte at a time until newline to avoid buffering past the
ans, err := reader.ReadString('\n') // end of the current line, which would consume input intended for
if err != nil { // subsequent Prompt calls on the same stdin
return false, fmt.Errorf("failed to parse input: %v", err) var line []byte
buf := make([]byte, 1)
for {
n, err := os.Stdin.Read(buf)
if n > 0 {
if buf[0] == '\n' {
break
}
if buf[0] != '\r' {
line = append(line, buf[0])
}
}
if err != nil {
if len(line) > 0 {
break
}
return false, fmt.Errorf("failed to prompt: %v", err)
}
} }
// normalize the answer // normalize the answer
ans = strings.ToLower(strings.TrimSpace(ans)) ans := strings.ToLower(strings.TrimSpace(string(line)))
// return the appropriate response // return the appropriate response
switch ans { switch ans {

View File

@@ -154,8 +154,8 @@ func TestPromptError(t *testing.T) {
if err == nil { if err == nil {
t.Error("expected error when reading from closed stdin, got nil") t.Error("expected error when reading from closed stdin, got nil")
} }
if !strings.Contains(err.Error(), "failed to parse input") { if !strings.Contains(err.Error(), "failed to prompt") {
t.Errorf("expected 'failed to parse input' error, got: %v", err) t.Errorf("expected 'failed to prompt' error, got: %v", err)
} }
} }

View File

@@ -17,13 +17,15 @@ func Run(configs string, confpath string) error {
// cheatsheets based on the user's platform // cheatsheets based on the user's platform
confdir := filepath.Dir(confpath) confdir := filepath.Dir(confpath)
// create paths for community and personal cheatsheets // create paths for community, personal, and work cheatsheets
community := filepath.Join(confdir, "cheatsheets", "community") community := filepath.Join(confdir, "cheatsheets", "community")
personal := filepath.Join(confdir, "cheatsheets", "personal") personal := filepath.Join(confdir, "cheatsheets", "personal")
work := filepath.Join(confdir, "cheatsheets", "work")
// set default cheatpaths // set default cheatpaths
configs = strings.Replace(configs, "COMMUNITY_PATH", community, -1) configs = strings.Replace(configs, "COMMUNITY_PATH", community, -1)
configs = strings.Replace(configs, "PERSONAL_PATH", personal, -1) configs = strings.Replace(configs, "PERSONAL_PATH", personal, -1)
configs = strings.Replace(configs, "WORK_PATH", work, -1)
// locate and set a default pager // locate and set a default pager
configs = strings.Replace(configs, "PAGER_PATH", config.Pager(), -1) configs = strings.Replace(configs, "PAGER_PATH", config.Pager(), -1)
@@ -44,15 +46,29 @@ func Run(configs string, confpath string) error {
// clone the community cheatsheets if so instructed // clone the community cheatsheets if so instructed
if yes { if yes {
// clone the community cheatsheets
fmt.Printf("Cloning community cheatsheets to %s.\n", community) fmt.Printf("Cloning community cheatsheets to %s.\n", community)
if err := repo.Clone(community); err != nil { if err := repo.Clone(community); err != nil {
return fmt.Errorf("failed to clone cheatsheets: %v", err) return fmt.Errorf("failed to clone cheatsheets: %v", err)
} }
} else {
// comment out the community cheatpath in the config since
// the directory won't exist
configs = strings.Replace(configs,
" - name: community\n"+
" path: "+community+"\n"+
" tags: [ community ]\n"+
" readonly: true",
" #- name: community\n"+
" # path: "+community+"\n"+
" # tags: [ community ]\n"+
" # readonly: true",
-1,
)
}
// also create a directory for personal cheatsheets // always create personal and work directories
fmt.Printf("Cloning personal cheatsheets to %s.\n", personal) for _, dir := range []string{personal, work} {
if err := os.MkdirAll(personal, os.ModePerm); err != nil { if err := os.MkdirAll(dir, os.ModePerm); err != nil {
return fmt.Errorf("failed to create directory: %v", err) return fmt.Errorf("failed to create directory: %v", err)
} }
} }

View File

@@ -5,6 +5,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"testing" "testing"
) )
@@ -53,8 +54,8 @@ cheatpaths:
confpath: filepath.Join(tempDir, "conf1", "conf.yml"), confpath: filepath.Join(tempDir, "conf1", "conf.yml"),
userInput: "n\n", userInput: "n\n",
wantErr: false, wantErr: false,
checkFiles: []string{"conf1/conf.yml"}, checkFiles: []string{"conf1/conf.yml", "conf1/cheatsheets/personal", "conf1/cheatsheets/work"},
dontWantFiles: []string{"conf1/cheatsheets/community", "conf1/cheatsheets/personal"}, dontWantFiles: []string{"conf1/cheatsheets/community"},
}, },
{ {
name: "user accepts but clone fails", name: "user accepts but clone fails",
@@ -69,15 +70,33 @@ cheatpaths:
wantInErr: "failed to clone cheatsheets", wantInErr: "failed to clone cheatsheets",
}, },
{ {
name: "invalid config path", name: "invalid config path",
configs: "test", configs: "test",
confpath: "/nonexistent/path/conf.yml", // /dev/null/... is truly uncreatable on Unix;
// NUL\... is uncreatable on Windows
confpath: func() string {
if runtime.GOOS == "windows" {
return `NUL\impossible\conf.yml`
}
return "/dev/null/impossible/conf.yml"
}(),
userInput: "n\n", userInput: "n\n",
wantErr: true, wantErr: true,
wantInErr: "failed to create config file", wantInErr: "failed to create",
}, },
} }
// Pre-create a .git dir inside the community path so go-git's PlainClone
// returns ErrRepositoryAlreadyExists (otherwise, on CI runners with
// network access, the real clone succeeds and the test fails)
fakeGitDir := filepath.Join(tempDir, "conf2", "cheatsheets", "community", ".git")
if err := os.MkdirAll(fakeGitDir, 0755); err != nil {
t.Fatalf("failed to create fake .git dir: %v", err)
}
if err := os.WriteFile(filepath.Join(fakeGitDir, "HEAD"), []byte("ref: refs/heads/main\n"), 0644); err != nil {
t.Fatalf("failed to write fake HEAD: %v", err)
}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Create stdin pipe // Create stdin pipe
@@ -158,10 +177,18 @@ func TestRunStringReplacements(t *testing.T) {
editor: EDITOR_PATH editor: EDITOR_PATH
pager: PAGER_PATH pager: PAGER_PATH
cheatpaths: cheatpaths:
- name: community
path: COMMUNITY_PATH
- name: personal - name: personal
path: PERSONAL_PATH path: PERSONAL_PATH
tags: [ personal ]
readonly: false
- name: work
path: WORK_PATH
tags: [ work ]
readonly: false
- name: community
path: COMMUNITY_PATH
tags: [ community ]
readonly: true
` `
// Create temp directory // Create temp directory
@@ -175,7 +202,6 @@ cheatpaths:
confdir := filepath.Dir(confpath) confdir := filepath.Dir(confpath)
// Expected paths // Expected paths
expectedCommunity := filepath.Join(confdir, "cheatsheets", "community")
expectedPersonal := filepath.Join(confdir, "cheatsheets", "personal") expectedPersonal := filepath.Join(confdir, "cheatsheets", "personal")
// Save original stdin/stdout // Save original stdin/stdout
@@ -225,10 +251,16 @@ cheatpaths:
if strings.Contains(contentStr, "PAGER_PATH") && !strings.Contains(contentStr, fmt.Sprintf("pager: %s", "")) { if strings.Contains(contentStr, "PAGER_PATH") && !strings.Contains(contentStr, fmt.Sprintf("pager: %s", "")) {
t.Error("PAGER_PATH was not replaced") t.Error("PAGER_PATH was not replaced")
} }
if strings.Contains(contentStr, "WORK_PATH") {
t.Error("WORK_PATH was not replaced")
}
// Verify correct paths were used // Verify community path is commented out (user declined)
if !strings.Contains(contentStr, expectedCommunity) { if strings.Contains(contentStr, " - name: community") {
t.Errorf("expected community path %q in config", expectedCommunity) t.Error("expected community cheatpath to be commented out when declined")
}
if !strings.Contains(contentStr, " #- name: community") {
t.Error("expected commented-out community cheatpath")
} }
if !strings.Contains(contentStr, expectedPersonal) { if !strings.Contains(contentStr, expectedPersonal) {
t.Errorf("expected personal path %q in config", expectedPersonal) t.Errorf("expected personal path %q in config", expectedPersonal)

View File

@@ -3,7 +3,6 @@ package mock
import ( import (
"fmt" "fmt"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
) )
@@ -16,7 +15,7 @@ func Path(filename string) string {
// compute the mock path // compute the mock path
file, err := filepath.Abs( file, err := filepath.Abs(
path.Join( filepath.Join(
filepath.Dir(thisfile), filepath.Dir(thisfile),
"../../mocks", "../../mocks",
filename, filename,

View File

@@ -3,6 +3,7 @@ package repo
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"testing" "testing"
) )
@@ -12,6 +13,9 @@ func TestClone(t *testing.T) {
// that don't require actual cloning // that don't require actual cloning
t.Run("clone to read-only directory", func(t *testing.T) { t.Run("clone to read-only directory", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("chmod does not restrict writes on Windows")
}
if os.Getuid() == 0 { if os.Getuid() == 0 {
t.Skip("Cannot test read-only directory as root") t.Skip("Cannot test read-only directory as root")
} }

View File

@@ -6,6 +6,11 @@ import (
"strings" "strings"
) )
// gitSep is the `.git` path component surrounded by path separators.
// Used to match `.git` as a complete path component, not as a suffix
// of a directory name (e.g., `personal.git`).
var gitSep = string(os.PathSeparator) + ".git" + string(os.PathSeparator)
// GitDir returns `true` if we are iterating over a directory contained within // GitDir returns `true` if we are iterating over a directory contained within
// a repositories `.git` directory. // a repositories `.git` directory.
func GitDir(path string) (bool, error) { func GitDir(path string) (bool, error) {
@@ -50,9 +55,20 @@ func GitDir(path string) (bool, error) {
See: https://github.com/cheat/cheat/issues/699 See: https://github.com/cheat/cheat/issues/699
Accounting for all of the above (hopefully?), the current solution is Accounting for all of the above, the next solution was to search not
not to search for `.git`, but `.git/` (including the directory for `.git`, but `.git/` (including the directory separator), and then
separator), and then only ceasing to walk the directory on a match. only ceasing to walk the directory on a match.
This, however, also had a bug: searching for `.git/` also matched
directory names that *ended with* `.git`, like `personal.git/`. This
caused cheatsheets stored under such paths to be silently skipped.
See: https://github.com/cheat/cheat/issues/711
The current (and hopefully final) solution requires the path separator
on *both* sides of `.git`, i.e., searching for `/.git/`. This ensures
that `.git` is matched only as a complete path component, not as a
suffix of a directory name.
To summarize, this code must account for the following possibilities: To summarize, this code must account for the following possibilities:
@@ -61,17 +77,16 @@ func GitDir(path string) (bool, error) {
3. A cheatpath is a repository, and contains a `.git*` file 3. A cheatpath is a repository, and contains a `.git*` file
4. A cheatpath is a submodule 4. A cheatpath is a submodule
5. A cheatpath is a hidden directory 5. A cheatpath is a hidden directory
6. A cheatpath is inside a directory whose name ends with `.git`
Care must be taken to support the above on both Unix and Windows Care must be taken to support the above on both Unix and Windows
systems, which have different directory separators and line-endings. systems, which have different directory separators and line-endings.
There is a lot of nuance to all of this, and it would be worthwhile to NB: `filepath.Walk` always passes absolute paths to the walk function,
do two things to stop writing bugs here: so `.git` will never appear as the first path component. This is what
makes the "separator on both sides" approach safe.
1. Build integration tests around all of this A reasonable smoke-test for ensuring that skipping is being applied
2. Discard string-matching solutions entirely, and use `go-git` instead
NB: A reasonable smoke-test for ensuring that skipping is being applied
correctly is to run the following command: correctly is to run the following command:
make && strace ./dist/cheat -l | wc -l make && strace ./dist/cheat -l | wc -l
@@ -83,8 +98,8 @@ func GitDir(path string) (bool, error) {
of syscalls should be significantly lower with the skip check enabled. of syscalls should be significantly lower with the skip check enabled.
*/ */
// determine if the literal string `.git` appears within `path` // determine if `.git` appears as a complete path component
pos := strings.Index(path, fmt.Sprintf(".git%s", string(os.PathSeparator))) pos := strings.Index(path, gitSep)
// if it does not, we know for certain that we are not within a `.git` // if it does not, we know for certain that we are not within a `.git`
// directory. // directory.

View File

@@ -1,137 +1,191 @@
package repo package repo
import ( import (
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
) )
func TestGitDir(t *testing.T) { // setupGitDirTestTree creates a temporary directory structure that exercises
// Create a temporary directory for testing // every case documented in GitDir's comment block. The caller must defer
tempDir, err := os.MkdirTemp("", "cheat-test-*") // os.RemoveAll on the returned root.
if err != nil { //
t.Fatalf("failed to create temp dir: %v", err) // Layout:
} //
defer os.RemoveAll(tempDir) // root/
// ├── plain/ # not a repository
// │ └── sheet
// ├── repo/ # a repository (.git is a directory)
// │ ├── .git/
// │ │ ├── HEAD
// │ │ ├── objects/
// │ │ │ └── pack/
// │ │ └── refs/
// │ │ └── heads/
// │ ├── .gitignore
// │ ├── .gitattributes
// │ └── sheet
// ├── submodule/ # a submodule (.git is a file)
// │ ├── .git # file, not directory
// │ └── sheet
// ├── dotgit-suffix.git/ # directory name ends in .git (#711)
// │ └── cheat/
// │ └── sheet
// ├── dotgit-mid.git/ # .git suffix mid-path (#711)
// │ └── nested/
// │ └── sheet
// ├── .github/ # .github directory (not .git)
// │ └── workflows/
// │ └── ci.yml
// └── .hidden/ # generic hidden directory
// └── sheet
func setupGitDirTestTree(t *testing.T) string {
t.Helper()
// Create test directory structure root := t.TempDir()
testDirs := []string{
filepath.Join(tempDir, ".git"), dirs := []string{
filepath.Join(tempDir, ".git", "objects"), // case 1: not a repository
filepath.Join(tempDir, ".git", "refs"), filepath.Join(root, "plain"),
filepath.Join(tempDir, "regular"),
filepath.Join(tempDir, "regular", ".git"), // case 2: a repository (.git directory with contents)
filepath.Join(tempDir, "submodule"), filepath.Join(root, "repo", ".git", "objects", "pack"),
filepath.Join(root, "repo", ".git", "refs", "heads"),
// case 4: a submodule (.git is a file)
filepath.Join(root, "submodule"),
// case 6: directory name ending in .git (#711)
filepath.Join(root, "dotgit-suffix.git", "cheat"),
filepath.Join(root, "dotgit-mid.git", "nested"),
// .github (should not be confused with .git)
filepath.Join(root, ".github", "workflows"),
// generic hidden directory
filepath.Join(root, ".hidden"),
} }
for _, dir := range testDirs { for _, dir := range dirs {
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
t.Fatalf("failed to create dir %s: %v", dir, err) t.Fatalf("failed to create dir %s: %v", dir, err)
} }
} }
// Create test files files := map[string]string{
testFiles := map[string]string{ // sheets
filepath.Join(tempDir, ".gitignore"): "*.tmp\n", filepath.Join(root, "plain", "sheet"): "plain sheet",
filepath.Join(tempDir, ".gitattributes"): "* text=auto\n", filepath.Join(root, "repo", "sheet"): "repo sheet",
filepath.Join(tempDir, "submodule", ".git"): "gitdir: ../.git/modules/submodule\n", filepath.Join(root, "submodule", "sheet"): "submod sheet",
filepath.Join(tempDir, "regular", "sheet.txt"): "content\n", filepath.Join(root, "dotgit-suffix.git", "cheat", "sheet"): "dotgit sheet",
filepath.Join(root, "dotgit-mid.git", "nested", "sheet"): "dotgit nested",
filepath.Join(root, ".hidden", "sheet"): "hidden sheet",
// git metadata
filepath.Join(root, "repo", ".git", "HEAD"): "ref: refs/heads/main\n",
filepath.Join(root, "repo", ".gitignore"): "*.tmp\n",
filepath.Join(root, "repo", ".gitattributes"): "* text=auto\n",
filepath.Join(root, "submodule", ".git"): "gitdir: ../.git/modules/sub\n",
filepath.Join(root, ".github", "workflows", "ci.yml"): "name: CI\n",
} }
for file, content := range testFiles { for path, content := range files {
if err := os.WriteFile(file, []byte(content), 0644); err != nil { if err := os.WriteFile(path, []byte(content), 0644); err != nil {
t.Fatalf("failed to create file %s: %v", file, err) t.Fatalf("failed to write %s: %v", path, err)
} }
} }
tests := []struct { return root
name string
path string
want bool
wantErr bool
}{
{
name: "not in git directory",
path: filepath.Join(tempDir, "regular", "sheet.txt"),
want: false,
},
{
name: "in .git directory",
path: filepath.Join(tempDir, ".git", "objects", "file"),
want: true,
},
{
name: "in .git/refs directory",
path: filepath.Join(tempDir, ".git", "refs", "heads", "main"),
want: true,
},
{
name: ".gitignore file",
path: filepath.Join(tempDir, ".gitignore"),
want: false,
},
{
name: ".gitattributes file",
path: filepath.Join(tempDir, ".gitattributes"),
want: false,
},
{
name: "submodule with .git file",
path: filepath.Join(tempDir, "submodule", "sheet.txt"),
want: false,
},
{
name: "path with .git in middle",
path: filepath.Join(tempDir, "regular", ".git", "sheet.txt"),
want: true,
},
{
name: "nonexistent path without .git",
path: filepath.Join(tempDir, "nonexistent", "file"),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := GitDir(tt.path)
if (err != nil) != tt.wantErr {
t.Errorf("GitDir() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("GitDir() = %v, want %v", got, tt.want)
}
})
}
} }
func TestGitDirEdgeCases(t *testing.T) { func TestGitDir(t *testing.T) {
// Test with paths that have .git but not as a directory separator root := setupGitDirTestTree(t)
tests := []struct { tests := []struct {
name string name string
path string path string
want bool want bool
}{ }{
// Case 1: not a repository — no .git anywhere in path
{ {
name: "file ending with .git", name: "plain directory, no repo",
path: "/tmp/myfile.git", path: filepath.Join(root, "plain", "sheet"),
want: false,
},
// Case 2: a repository — paths *inside* .git/ should be detected
{
name: "inside .git directory",
path: filepath.Join(root, "repo", ".git", "HEAD"),
want: true,
},
{
name: "inside .git/objects",
path: filepath.Join(root, "repo", ".git", "objects", "pack", "somefile"),
want: true,
},
{
name: "inside .git/refs",
path: filepath.Join(root, "repo", ".git", "refs", "heads", "main"),
want: true,
},
// Case 2 (cont.): files *alongside* .git should NOT be detected
{
name: "sheet in repo root (beside .git dir)",
path: filepath.Join(root, "repo", "sheet"),
want: false,
},
// Case 3: .git* files (like .gitignore) should NOT trigger
{
name: ".gitignore file",
path: filepath.Join(root, "repo", ".gitignore"),
want: false, want: false,
}, },
{ {
name: "directory ending with .git", name: ".gitattributes file",
path: "/tmp/myrepo.git", path: filepath.Join(root, "repo", ".gitattributes"),
want: false,
},
// Case 4: submodule — .git is a file, not a directory
{
name: "sheet in submodule (where .git is a file)",
path: filepath.Join(root, "submodule", "sheet"),
want: false,
},
// Case 6: directory name ends with .git (#711)
{
name: "sheet under directory ending in .git",
path: filepath.Join(root, "dotgit-suffix.git", "cheat", "sheet"),
want: false, want: false,
}, },
{ {
name: ".github directory", name: "sheet under .git-suffixed dir, nested deeper",
path: "/tmp/.github/workflows", path: filepath.Join(root, "dotgit-mid.git", "nested", "sheet"),
want: false, want: false,
}, },
// .github directory — must not be confused with .git
{ {
name: "legitimate.git-repo name", name: "file inside .github directory",
path: "/tmp/legitimate.git-repo/file", path: filepath.Join(root, ".github", "workflows", "ci.yml"),
want: false,
},
// Hidden directory that is not .git
{
name: "file inside generic hidden directory",
path: filepath.Join(root, ".hidden", "sheet"),
want: false,
},
// Path with no .git at all
{
name: "path with no .git component whatsoever",
path: filepath.Join(root, "nonexistent", "file"),
want: false, want: false,
}, },
} }
@@ -140,8 +194,7 @@ func TestGitDirEdgeCases(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := GitDir(tt.path) got, err := GitDir(tt.path)
if err != nil { if err != nil {
// It's ok if the path doesn't exist for these edge case tests t.Fatalf("GitDir(%q) returned unexpected error: %v", tt.path, err)
return
} }
if got != tt.want { if got != tt.want {
t.Errorf("GitDir(%q) = %v, want %v", tt.path, got, tt.want) t.Errorf("GitDir(%q) = %v, want %v", tt.path, got, tt.want)
@@ -150,28 +203,153 @@ func TestGitDirEdgeCases(t *testing.T) {
} }
} }
func TestGitDirPathSeparator(t *testing.T) { // TestGitDirWithNestedGitDir tests a repo inside a .git-suffixed parent
// Test that the function correctly uses os.PathSeparator // directory. This is the nastiest combination: a real .git directory that
// This is important for cross-platform compatibility // appears *after* a .git suffix in the path.
func TestGitDirWithNestedGitDir(t *testing.T) {
root := t.TempDir()
// Create a path with the wrong separator for the current OS // Create: root/cheats.git/repo/.git/HEAD
var wrongSep string // root/cheats.git/repo/sheet
if os.PathSeparator == '/' { gitDir := filepath.Join(root, "cheats.git", "repo", ".git")
wrongSep = `\` if err := os.MkdirAll(gitDir, 0755); err != nil {
} else { t.Fatal(err)
wrongSep = `/` }
if err := os.WriteFile(filepath.Join(gitDir, "HEAD"), []byte("ref: refs/heads/main\n"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(root, "cheats.git", "repo", "sheet"), []byte("content"), 0644); err != nil {
t.Fatal(err)
} }
// Path with wrong separator should not be detected as git dir tests := []struct {
path := fmt.Sprintf("some%spath%s.git%sfile", wrongSep, wrongSep, wrongSep) name string
isGit, err := GitDir(path) path string
want bool
if err != nil { }{
// Path doesn't exist, which is fine {
return name: "sheet beside .git in .git-suffixed parent",
path: filepath.Join(root, "cheats.git", "repo", "sheet"),
want: false,
},
{
name: "file inside .git inside .git-suffixed parent",
path: filepath.Join(root, "cheats.git", "repo", ".git", "HEAD"),
want: true,
},
} }
if isGit { for _, tt := range tests {
t.Errorf("GitDir() incorrectly detected git dir with wrong path separator") t.Run(tt.name, func(t *testing.T) {
got, err := GitDir(tt.path)
if err != nil {
t.Fatalf("GitDir(%q) returned unexpected error: %v", tt.path, err)
}
if got != tt.want {
t.Errorf("GitDir(%q) = %v, want %v", tt.path, got, tt.want)
}
})
}
}
// TestGitDirSubmoduleInsideDotGitSuffix tests a submodule (.git file)
// inside a .git-suffixed parent directory.
func TestGitDirSubmoduleInsideDotGitSuffix(t *testing.T) {
root := t.TempDir()
// Create: root/personal.git/submod/.git (file)
// root/personal.git/submod/sheet
subDir := filepath.Join(root, "personal.git", "submod")
if err := os.MkdirAll(subDir, 0755); err != nil {
t.Fatal(err)
}
// .git as a file (submodule pointer)
if err := os.WriteFile(filepath.Join(subDir, ".git"), []byte("gitdir: ../../.git/modules/sub\n"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(subDir, "sheet"), []byte("content"), 0644); err != nil {
t.Fatal(err)
}
got, err := GitDir(filepath.Join(subDir, "sheet"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got {
t.Error("GitDir should return false for sheet in submodule under .git-suffixed parent")
}
}
// TestGitDirIntegrationWalk simulates what sheets.Load does: walking a
// directory tree and checking each path with GitDir. This verifies that
// the function works correctly in the context of filepath.Walk, which is
// how it is actually called.
func TestGitDirIntegrationWalk(t *testing.T) {
root := setupGitDirTestTree(t)
// Walk the tree and collect which paths GitDir says to skip
var skipped []string
var visited []string
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
isGit, err := GitDir(path)
if err != nil {
return err
}
if isGit {
skipped = append(skipped, path)
} else {
visited = append(visited, path)
}
return nil
})
if err != nil {
t.Fatalf("Walk failed: %v", err)
}
// Files inside .git/ should be skipped
expectSkipped := []string{
filepath.Join(root, "repo", ".git", "HEAD"),
}
for _, want := range expectSkipped {
found := false
for _, got := range skipped {
if got == want {
found = true
break
}
}
if !found {
t.Errorf("expected %q to be skipped, but it was not", want)
}
}
// Sheets should NOT be skipped — including the #711 case
expectVisited := []string{
filepath.Join(root, "plain", "sheet"),
filepath.Join(root, "repo", "sheet"),
filepath.Join(root, "submodule", "sheet"),
filepath.Join(root, "dotgit-suffix.git", "cheat", "sheet"),
filepath.Join(root, "dotgit-mid.git", "nested", "sheet"),
filepath.Join(root, ".hidden", "sheet"),
}
for _, want := range expectVisited {
found := false
for _, got := range visited {
if got == want {
found = true
break
}
}
if !found {
t.Errorf("expected %q to be visited (not skipped), but it was not found in visited paths", want)
}
} }
} }

View File

@@ -3,6 +3,7 @@ package sheet
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"testing" "testing"
) )
@@ -130,6 +131,10 @@ func TestCopyIOError(t *testing.T) {
// TestCopyCleanupOnError verifies that partially written files are cleaned up on error // TestCopyCleanupOnError verifies that partially written files are cleaned up on error
func TestCopyCleanupOnError(t *testing.T) { func TestCopyCleanupOnError(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("chmod does not restrict reads on Windows")
}
// Create a source file that we'll make unreadable after opening // Create a source file that we'll make unreadable after opening
src, err := os.CreateTemp("", "copy-test-cleanup-*") src, err := os.CreateTemp("", "copy-test-cleanup-*")
if err != nil { if err != nil {

View File

@@ -2,7 +2,6 @@ package sheet
import ( import (
"fmt" "fmt"
"runtime"
"strings" "strings"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -11,9 +10,9 @@ import (
// Parse parses cheatsheet frontmatter // Parse parses cheatsheet frontmatter
func parse(markdown string) (frontmatter, string, error) { func parse(markdown string) (frontmatter, string, error) {
// determine the appropriate line-break for the platform // detect the line-break style used in the content
linebreak := "\n" linebreak := "\n"
if runtime.GOOS == "windows" { if strings.Contains(markdown, "\r\n") {
linebreak = "\r\n" linebreak = "\r\n"
} }

View File

@@ -1,17 +1,11 @@
package sheet package sheet
import ( import (
"runtime"
"testing" "testing"
) )
// TestParseWindowsLineEndings tests parsing with Windows line endings // TestParseWindowsLineEndings tests parsing with Windows line endings
func TestParseWindowsLineEndings(t *testing.T) { func TestParseWindowsLineEndings(t *testing.T) {
// Only test Windows line endings on Windows
if runtime.GOOS != "windows" {
t.Skip("Skipping Windows line ending test on non-Windows platform")
}
// stub our cheatsheet content with Windows line endings // stub our cheatsheet content with Windows line endings
markdown := "---\r\nsyntax: go\r\ntags: [ test ]\r\n---\r\nTo foo the bar: baz" markdown := "---\r\nsyntax: go\r\ntags: [ test ]\r\n---\r\nTo foo the bar: baz"

View File

@@ -13,6 +13,9 @@
# Output of the go coverage tool, specifically when used with LiteIDE # Output of the go coverage tool, specifically when used with LiteIDE
*.out *.out
# Golang/Intellij
.idea
# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
.glide/ .glide/

7
vendor/dario.cat/mergo/FUNDING.json vendored Normal file
View File

@@ -0,0 +1,7 @@
{
"drips": {
"ethereum": {
"ownedBy": "0x6160020e7102237aC41bdb156e94401692D76930"
}
}
}

View File

@@ -44,13 +44,21 @@ Also a lovely [comune](http://en.wikipedia.org/wiki/Mergo) (municipality) in the
## Status ## Status
It is ready for production use. [It is used in several projects by Docker, Google, The Linux Foundation, VMWare, Shopify, Microsoft, etc](https://github.com/imdario/mergo#mergo-in-the-wild). Mergo is stable and frozen, ready for production. Check a short list of the projects using at large scale it [here](https://github.com/imdario/mergo#mergo-in-the-wild).
No new features are accepted. They will be considered for a future v2 that improves the implementation and fixes bugs for corner cases.
### Important notes ### Important notes
#### 1.0.0 #### 1.0.0
In [1.0.0](//github.com/imdario/mergo/releases/tag/1.0.0) Mergo moves to a vanity URL `dario.cat/mergo`. In [1.0.0](//github.com/imdario/mergo/releases/tag/1.0.0) Mergo moves to a vanity URL `dario.cat/mergo`. No more v1 versions will be released.
If the vanity URL is causing issues in your project due to a dependency pulling Mergo - it isn't a direct dependency in your project - it is recommended to use [replace](https://github.com/golang/go/wiki/Modules#when-should-i-use-the-replace-directive) to pin the version to the last one with the old import URL:
```
replace github.com/imdario/mergo => github.com/imdario/mergo v0.3.16
```
#### 0.3.9 #### 0.3.9
@@ -64,55 +72,23 @@ If you were using Mergo before April 6th, 2015, please check your project works
If Mergo is useful to you, consider buying me a coffee, a beer, or making a monthly donation to allow me to keep building great free software. :heart_eyes: If Mergo is useful to you, consider buying me a coffee, a beer, or making a monthly donation to allow me to keep building great free software. :heart_eyes:
<a href='https://ko-fi.com/B0B58839' target='_blank'><img height='36' style='border:0px;height:36px;' src='https://az743702.vo.msecnd.net/cdn/kofi1.png?v=0' border='0' alt='Buy Me a Coffee at ko-fi.com' /></a>
<a href="https://liberapay.com/dario/donate"><img alt="Donate using Liberapay" src="https://liberapay.com/assets/widgets/donate.svg"></a> <a href="https://liberapay.com/dario/donate"><img alt="Donate using Liberapay" src="https://liberapay.com/assets/widgets/donate.svg"></a>
<a href='https://github.com/sponsors/imdario' target='_blank'><img alt="Become my sponsor" src="https://img.shields.io/github/sponsors/imdario?style=for-the-badge" /></a> <a href='https://github.com/sponsors/imdario' target='_blank'><img alt="Become my sponsor" src="https://img.shields.io/github/sponsors/imdario?style=for-the-badge" /></a>
### Mergo in the wild ### Mergo in the wild
- [moby/moby](https://github.com/moby/moby) Mergo is used by [thousands](https://deps.dev/go/dario.cat%2Fmergo/v1.0.0/dependents) [of](https://deps.dev/go/github.com%2Fimdario%2Fmergo/v0.3.16/dependents) [projects](https://deps.dev/go/github.com%2Fimdario%2Fmergo/v0.3.12), including:
- [kubernetes/kubernetes](https://github.com/kubernetes/kubernetes)
- [vmware/dispatch](https://github.com/vmware/dispatch) * [containerd/containerd](https://github.com/containerd/containerd)
- [Shopify/themekit](https://github.com/Shopify/themekit) * [datadog/datadog-agent](https://github.com/datadog/datadog-agent)
- [imdario/zas](https://github.com/imdario/zas) * [docker/cli/](https://github.com/docker/cli/)
- [matcornic/hermes](https://github.com/matcornic/hermes) * [goreleaser/goreleaser](https://github.com/goreleaser/goreleaser)
- [OpenBazaar/openbazaar-go](https://github.com/OpenBazaar/openbazaar-go) * [go-micro/go-micro](https://github.com/go-micro/go-micro)
- [kataras/iris](https://github.com/kataras/iris) * [grafana/loki](https://github.com/grafana/loki)
- [michaelsauter/crane](https://github.com/michaelsauter/crane) * [masterminds/sprig](github.com/Masterminds/sprig)
- [go-task/task](https://github.com/go-task/task) * [moby/moby](https://github.com/moby/moby)
- [sensu/uchiwa](https://github.com/sensu/uchiwa) * [slackhq/nebula](https://github.com/slackhq/nebula)
- [ory/hydra](https://github.com/ory/hydra) * [volcano-sh/volcano](https://github.com/volcano-sh/volcano)
- [sisatech/vcli](https://github.com/sisatech/vcli)
- [dairycart/dairycart](https://github.com/dairycart/dairycart)
- [projectcalico/felix](https://github.com/projectcalico/felix)
- [resin-os/balena](https://github.com/resin-os/balena)
- [go-kivik/kivik](https://github.com/go-kivik/kivik)
- [Telefonica/govice](https://github.com/Telefonica/govice)
- [supergiant/supergiant](supergiant/supergiant)
- [SergeyTsalkov/brooce](https://github.com/SergeyTsalkov/brooce)
- [soniah/dnsmadeeasy](https://github.com/soniah/dnsmadeeasy)
- [ohsu-comp-bio/funnel](https://github.com/ohsu-comp-bio/funnel)
- [EagerIO/Stout](https://github.com/EagerIO/Stout)
- [lynndylanhurley/defsynth-api](https://github.com/lynndylanhurley/defsynth-api)
- [russross/canvasassignments](https://github.com/russross/canvasassignments)
- [rdegges/cryptly-api](https://github.com/rdegges/cryptly-api)
- [casualjim/exeggutor](https://github.com/casualjim/exeggutor)
- [divshot/gitling](https://github.com/divshot/gitling)
- [RWJMurphy/gorl](https://github.com/RWJMurphy/gorl)
- [andrerocker/deploy42](https://github.com/andrerocker/deploy42)
- [elwinar/rambler](https://github.com/elwinar/rambler)
- [tmaiaroto/gopartman](https://github.com/tmaiaroto/gopartman)
- [jfbus/impressionist](https://github.com/jfbus/impressionist)
- [Jmeyering/zealot](https://github.com/Jmeyering/zealot)
- [godep-migrator/rigger-host](https://github.com/godep-migrator/rigger-host)
- [Dronevery/MultiwaySwitch-Go](https://github.com/Dronevery/MultiwaySwitch-Go)
- [thoas/picfit](https://github.com/thoas/picfit)
- [mantasmatelis/whooplist-server](https://github.com/mantasmatelis/whooplist-server)
- [jnuthong/item_search](https://github.com/jnuthong/item_search)
- [bukalapak/snowboard](https://github.com/bukalapak/snowboard)
- [containerssh/containerssh](https://github.com/containerssh/containerssh)
- [goreleaser/goreleaser](https://github.com/goreleaser/goreleaser)
- [tjpnz/structbot](https://github.com/tjpnz/structbot)
## Install ## Install
@@ -141,6 +117,39 @@ if err := mergo.Merge(&dst, src, mergo.WithOverride); err != nil {
} }
``` ```
If you need to override pointers, so the source pointer's value is assigned to the destination's pointer, you must use `WithoutDereference`:
```go
package main
import (
"fmt"
"dario.cat/mergo"
)
type Foo struct {
A *string
B int64
}
func main() {
first := "first"
second := "second"
src := Foo{
A: &first,
B: 2,
}
dest := Foo{
A: &second,
B: 1,
}
mergo.Merge(&dest, src, mergo.WithOverride, mergo.WithoutDereference)
}
```
Additionally, you can map a `map[string]interface{}` to a struct (and otherwise, from struct to map), following the same restrictions as in `Merge()`. Keys are capitalized to find each corresponding exported field. Additionally, you can map a `map[string]interface{}` to a struct (and otherwise, from struct to map), following the same restrictions as in `Merge()`. Keys are capitalized to find each corresponding exported field.
```go ```go
@@ -181,10 +190,6 @@ func main() {
} }
``` ```
Note: if test are failing due missing package, please execute:
go get gopkg.in/yaml.v3
### Transformers ### Transformers
Transformers allow to merge specific types differently than in the default behavior. In other words, now you can customize how some types are merged. For example, `time.Time` is a struct; it doesn't have zero value but IsZero can return true because it has fields with zero value. How can we merge a non-zero `time.Time`? Transformers allow to merge specific types differently than in the default behavior. In other words, now you can customize how some types are merged. For example, `time.Time` is a struct; it doesn't have zero value but IsZero can return true because it has fields with zero value. How can we merge a non-zero `time.Time`?

View File

@@ -4,8 +4,8 @@
| Version | Supported | | Version | Supported |
| ------- | ------------------ | | ------- | ------------------ |
| 0.3.x | :white_check_mark: | | 1.x.x | :white_check_mark: |
| < 0.3 | :x: | | < 1.0 | :x: |
## Security contact information ## Security contact information

View File

@@ -58,7 +58,7 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, conf
} }
fieldName := field.Name fieldName := field.Name
fieldName = changeInitialCase(fieldName, unicode.ToLower) fieldName = changeInitialCase(fieldName, unicode.ToLower)
if v, ok := dstMap[fieldName]; !ok || (isEmptyValue(reflect.ValueOf(v), !config.ShouldNotDereference) || overwrite) { if _, ok := dstMap[fieldName]; !ok || (!isEmptyValue(reflect.ValueOf(src.Field(i).Interface()), !config.ShouldNotDereference) && overwrite) || config.overwriteWithEmptyValue {
dstMap[fieldName] = src.Field(i).Interface() dstMap[fieldName] = src.Field(i).Interface()
} }
} }

View File

@@ -269,7 +269,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, co
if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil { if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, config); err != nil {
return return
} }
} else { } else if src.Elem().Kind() != reflect.Struct {
if overwriteWithEmptySrc || (overwrite && !src.IsNil()) || dst.IsNil() { if overwriteWithEmptySrc || (overwrite && !src.IsNil()) || dst.IsNil() {
dst.Set(src) dst.Set(src)
} }

View File

@@ -1,7 +1,3 @@
run:
skip-dirs:
- pkg/etw/sample
linters: linters:
enable: enable:
# style # style
@@ -20,9 +16,13 @@ linters:
- gofmt # files are gofmt'ed - gofmt # files are gofmt'ed
- gosec # security - gosec # security
- nilerr # returns nil even with non-nil error - nilerr # returns nil even with non-nil error
- thelper # test helpers without t.Helper()
- unparam # unused function params - unparam # unused function params
issues: issues:
exclude-dirs:
- pkg/etw/sample
exclude-rules: exclude-rules:
# err is very often shadowed in nested scopes # err is very often shadowed in nested scopes
- linters: - linters:
@@ -69,9 +69,7 @@ linters-settings:
# struct order is often for Win32 compat # struct order is often for Win32 compat
# also, ignore pointer bytes/GC issues for now until performance becomes an issue # also, ignore pointer bytes/GC issues for now until performance becomes an issue
- fieldalignment - fieldalignment
check-shadowing: true
nolintlint: nolintlint:
allow-leading-space: false
require-explanation: true require-explanation: true
require-specific: true require-specific: true
revive: revive:

View File

@@ -10,14 +10,14 @@ import (
"io" "io"
"os" "os"
"runtime" "runtime"
"syscall"
"unicode/utf16" "unicode/utf16"
"github.com/Microsoft/go-winio/internal/fs"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
//sys backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead //sys backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead
//sys backupWrite(h syscall.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite //sys backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite
const ( const (
BackupData = uint32(iota + 1) BackupData = uint32(iota + 1)
@@ -104,7 +104,7 @@ func (r *BackupStreamReader) Next() (*BackupHeader, error) {
if err := binary.Read(r.r, binary.LittleEndian, name); err != nil { if err := binary.Read(r.r, binary.LittleEndian, name); err != nil {
return nil, err return nil, err
} }
hdr.Name = syscall.UTF16ToString(name) hdr.Name = windows.UTF16ToString(name)
} }
if wsi.StreamID == BackupSparseBlock { if wsi.StreamID == BackupSparseBlock {
if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil { if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil {
@@ -205,7 +205,7 @@ func NewBackupFileReader(f *os.File, includeSecurity bool) *BackupFileReader {
// Read reads a backup stream from the file by calling the Win32 API BackupRead(). // Read reads a backup stream from the file by calling the Win32 API BackupRead().
func (r *BackupFileReader) Read(b []byte) (int, error) { func (r *BackupFileReader) Read(b []byte) (int, error) {
var bytesRead uint32 var bytesRead uint32
err := backupRead(syscall.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx) err := backupRead(windows.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx)
if err != nil { if err != nil {
return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err} return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err}
} }
@@ -220,7 +220,7 @@ func (r *BackupFileReader) Read(b []byte) (int, error) {
// the underlying file. // the underlying file.
func (r *BackupFileReader) Close() error { func (r *BackupFileReader) Close() error {
if r.ctx != 0 { if r.ctx != 0 {
_ = backupRead(syscall.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx) _ = backupRead(windows.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx)
runtime.KeepAlive(r.f) runtime.KeepAlive(r.f)
r.ctx = 0 r.ctx = 0
} }
@@ -244,7 +244,7 @@ func NewBackupFileWriter(f *os.File, includeSecurity bool) *BackupFileWriter {
// Write restores a portion of the file using the provided backup stream. // Write restores a portion of the file using the provided backup stream.
func (w *BackupFileWriter) Write(b []byte) (int, error) { func (w *BackupFileWriter) Write(b []byte) (int, error) {
var bytesWritten uint32 var bytesWritten uint32
err := backupWrite(syscall.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx) err := backupWrite(windows.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx)
if err != nil { if err != nil {
return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err} return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err}
} }
@@ -259,7 +259,7 @@ func (w *BackupFileWriter) Write(b []byte) (int, error) {
// close the underlying file. // close the underlying file.
func (w *BackupFileWriter) Close() error { func (w *BackupFileWriter) Close() error {
if w.ctx != 0 { if w.ctx != 0 {
_ = backupWrite(syscall.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx) _ = backupWrite(windows.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx)
runtime.KeepAlive(w.f) runtime.KeepAlive(w.f)
w.ctx = 0 w.ctx = 0
} }
@@ -271,17 +271,14 @@ func (w *BackupFileWriter) Close() error {
// //
// If the file opened was a directory, it cannot be used with Readdir(). // If the file opened was a directory, it cannot be used with Readdir().
func OpenForBackup(path string, access uint32, share uint32, createmode uint32) (*os.File, error) { func OpenForBackup(path string, access uint32, share uint32, createmode uint32) (*os.File, error) {
winPath, err := syscall.UTF16FromString(path) h, err := fs.CreateFile(path,
if err != nil { fs.AccessMask(access),
return nil, err fs.FileShareMode(share),
}
h, err := syscall.CreateFile(&winPath[0],
access,
share,
nil, nil,
createmode, fs.FileCreationDisposition(createmode),
syscall.FILE_FLAG_BACKUP_SEMANTICS|syscall.FILE_FLAG_OPEN_REPARSE_POINT, fs.FILE_FLAG_BACKUP_SEMANTICS|fs.FILE_FLAG_OPEN_REPARSE_POINT,
0) 0,
)
if err != nil { if err != nil {
err = &os.PathError{Op: "open", Path: path, Err: err} err = &os.PathError{Op: "open", Path: path, Err: err}
return nil, err return nil, err

View File

@@ -15,26 +15,11 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx //sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort //sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus //sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes //sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult //sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
type atomicBool int32
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
//revive:disable-next-line:predeclared Keep "new" to maintain consistency with "atomic" pkg
func (b *atomicBool) swap(new bool) bool {
var newInt int32
if new {
newInt = 1
}
return atomic.SwapInt32((*int32)(b), newInt) == 1
}
var ( var (
ErrFileClosed = errors.New("file has already been closed") ErrFileClosed = errors.New("file has already been closed")
@@ -50,7 +35,7 @@ func (*timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{} type timeoutChan chan struct{}
var ioInitOnce sync.Once var ioInitOnce sync.Once
var ioCompletionPort syscall.Handle var ioCompletionPort windows.Handle
// ioResult contains the result of an asynchronous IO operation. // ioResult contains the result of an asynchronous IO operation.
type ioResult struct { type ioResult struct {
@@ -60,12 +45,12 @@ type ioResult struct {
// ioOperation represents an outstanding asynchronous Win32 IO. // ioOperation represents an outstanding asynchronous Win32 IO.
type ioOperation struct { type ioOperation struct {
o syscall.Overlapped o windows.Overlapped
ch chan ioResult ch chan ioResult
} }
func initIO() { func initIO() {
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff) h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -76,10 +61,10 @@ func initIO() {
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. // win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected. // It takes ownership of this handle and will close it if it is garbage collected.
type win32File struct { type win32File struct {
handle syscall.Handle handle windows.Handle
wg sync.WaitGroup wg sync.WaitGroup
wgLock sync.RWMutex wgLock sync.RWMutex
closing atomicBool closing atomic.Bool
socket bool socket bool
readDeadline deadlineHandler readDeadline deadlineHandler
writeDeadline deadlineHandler writeDeadline deadlineHandler
@@ -90,11 +75,11 @@ type deadlineHandler struct {
channel timeoutChan channel timeoutChan
channelLock sync.RWMutex channelLock sync.RWMutex
timer *time.Timer timer *time.Timer
timedout atomicBool timedout atomic.Bool
} }
// makeWin32File makes a new win32File from an existing file handle. // makeWin32File makes a new win32File from an existing file handle.
func makeWin32File(h syscall.Handle) (*win32File, error) { func makeWin32File(h windows.Handle) (*win32File, error) {
f := &win32File{handle: h} f := &win32File{handle: h}
ioInitOnce.Do(initIO) ioInitOnce.Do(initIO)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
@@ -110,7 +95,12 @@ func makeWin32File(h syscall.Handle) (*win32File, error) {
return f, nil return f, nil
} }
// Deprecated: use NewOpenFile instead.
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) { func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
return NewOpenFile(windows.Handle(h))
}
func NewOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
// If we return the result of makeWin32File directly, it can result in an // If we return the result of makeWin32File directly, it can result in an
// interface-wrapped nil, rather than a nil interface value. // interface-wrapped nil, rather than a nil interface value.
f, err := makeWin32File(h) f, err := makeWin32File(h)
@@ -124,13 +114,13 @@ func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
func (f *win32File) closeHandle() { func (f *win32File) closeHandle() {
f.wgLock.Lock() f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once. // Atomically set that we are closing, releasing the resources only once.
if !f.closing.swap(true) { if !f.closing.Swap(true) {
f.wgLock.Unlock() f.wgLock.Unlock()
// cancel all IO and wait for it to complete // cancel all IO and wait for it to complete
_ = cancelIoEx(f.handle, nil) _ = cancelIoEx(f.handle, nil)
f.wg.Wait() f.wg.Wait()
// at this point, no new IO can start // at this point, no new IO can start
syscall.Close(f.handle) windows.Close(f.handle)
f.handle = 0 f.handle = 0
} else { } else {
f.wgLock.Unlock() f.wgLock.Unlock()
@@ -145,14 +135,14 @@ func (f *win32File) Close() error {
// IsClosed checks if the file has been closed. // IsClosed checks if the file has been closed.
func (f *win32File) IsClosed() bool { func (f *win32File) IsClosed() bool {
return f.closing.isSet() return f.closing.Load()
} }
// prepareIO prepares for a new IO operation. // prepareIO prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *win32File) prepareIO() (*ioOperation, error) { func (f *win32File) prepareIO() (*ioOperation, error) {
f.wgLock.RLock() f.wgLock.RLock()
if f.closing.isSet() { if f.closing.Load() {
f.wgLock.RUnlock() f.wgLock.RUnlock()
return nil, ErrFileClosed return nil, ErrFileClosed
} }
@@ -164,12 +154,12 @@ func (f *win32File) prepareIO() (*ioOperation, error) {
} }
// ioCompletionProcessor processes completed async IOs forever. // ioCompletionProcessor processes completed async IOs forever.
func ioCompletionProcessor(h syscall.Handle) { func ioCompletionProcessor(h windows.Handle) {
for { for {
var bytes uint32 var bytes uint32
var key uintptr var key uintptr
var op *ioOperation var op *ioOperation
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE) err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
if op == nil { if op == nil {
panic(err) panic(err)
} }
@@ -182,11 +172,11 @@ func ioCompletionProcessor(h syscall.Handle) {
// asyncIO processes the return value from ReadFile or WriteFile, blocking until // asyncIO processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed. // the operation has actually completed.
func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING { //nolint:errorlint // err is Errno if err != windows.ERROR_IO_PENDING { //nolint:errorlint // err is Errno
return int(bytes), err return int(bytes), err
} }
if f.closing.isSet() { if f.closing.Load() {
_ = cancelIoEx(f.handle, &c.o) _ = cancelIoEx(f.handle, &c.o)
} }
@@ -201,8 +191,8 @@ func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, er
select { select {
case r = <-c.ch: case r = <-c.ch:
err = r.err err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
if f.closing.isSet() { if f.closing.Load() {
err = ErrFileClosed err = ErrFileClosed
} }
} else if err != nil && f.socket { } else if err != nil && f.socket {
@@ -214,7 +204,7 @@ func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, er
_ = cancelIoEx(f.handle, &c.o) _ = cancelIoEx(f.handle, &c.o)
r = <-c.ch r = <-c.ch
err = r.err err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
err = ErrTimeout err = ErrTimeout
} }
} }
@@ -235,23 +225,22 @@ func (f *win32File) Read(b []byte) (int, error) {
} }
defer f.wg.Done() defer f.wg.Done()
if f.readDeadline.timedout.isSet() { if f.readDeadline.timedout.Load() {
return 0, ErrTimeout return 0, ErrTimeout
} }
var bytes uint32 var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o) err = windows.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIO(c, &f.readDeadline, bytes, err) n, err := f.asyncIO(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b) runtime.KeepAlive(b)
// Handle EOF conditions. // Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 { if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF return 0, io.EOF
} else if err == syscall.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno } else if err == windows.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno
return 0, io.EOF return 0, io.EOF
} else {
return n, err
} }
return n, err
} }
// Write writes to a file handle. // Write writes to a file handle.
@@ -262,12 +251,12 @@ func (f *win32File) Write(b []byte) (int, error) {
} }
defer f.wg.Done() defer f.wg.Done()
if f.writeDeadline.timedout.isSet() { if f.writeDeadline.timedout.Load() {
return 0, ErrTimeout return 0, ErrTimeout
} }
var bytes uint32 var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o) err = windows.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIO(c, &f.writeDeadline, bytes, err) n, err := f.asyncIO(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b) runtime.KeepAlive(b)
return n, err return n, err
@@ -282,7 +271,7 @@ func (f *win32File) SetWriteDeadline(deadline time.Time) error {
} }
func (f *win32File) Flush() error { func (f *win32File) Flush() error {
return syscall.FlushFileBuffers(f.handle) return windows.FlushFileBuffers(f.handle)
} }
func (f *win32File) Fd() uintptr { func (f *win32File) Fd() uintptr {
@@ -299,7 +288,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
} }
d.timer = nil d.timer = nil
} }
d.timedout.setFalse() d.timedout.Store(false)
select { select {
case <-d.channel: case <-d.channel:
@@ -314,7 +303,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
} }
timeoutIO := func() { timeoutIO := func() {
d.timedout.setTrue() d.timedout.Store(true)
close(d.channel) close(d.channel)
} }

View File

@@ -18,9 +18,18 @@ type FileBasicInfo struct {
_ uint32 // padding _ uint32 // padding
} }
// alignedFileBasicInfo is a FileBasicInfo, but aligned to uint64 by containing
// uint64 rather than windows.Filetime. Filetime contains two uint32s. uint64
// alignment is necessary to pass this as FILE_BASIC_INFO.
type alignedFileBasicInfo struct {
CreationTime, LastAccessTime, LastWriteTime, ChangeTime uint64
FileAttributes uint32
_ uint32 // padding
}
// GetFileBasicInfo retrieves times and attributes for a file. // GetFileBasicInfo retrieves times and attributes for a file.
func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) { func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
bi := &FileBasicInfo{} bi := &alignedFileBasicInfo{}
if err := windows.GetFileInformationByHandleEx( if err := windows.GetFileInformationByHandleEx(
windows.Handle(f.Fd()), windows.Handle(f.Fd()),
windows.FileBasicInfo, windows.FileBasicInfo,
@@ -30,16 +39,21 @@ func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
} }
runtime.KeepAlive(f) runtime.KeepAlive(f)
return bi, nil // Reinterpret the alignedFileBasicInfo as a FileBasicInfo so it matches the
// public API of this module. The data may be unnecessarily aligned.
return (*FileBasicInfo)(unsafe.Pointer(bi)), nil
} }
// SetFileBasicInfo sets times and attributes for a file. // SetFileBasicInfo sets times and attributes for a file.
func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error { func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error {
// Create an alignedFileBasicInfo based on a FileBasicInfo. The copy is
// suitable to pass to GetFileInformationByHandleEx.
biAligned := *(*alignedFileBasicInfo)(unsafe.Pointer(bi))
if err := windows.SetFileInformationByHandle( if err := windows.SetFileInformationByHandle(
windows.Handle(f.Fd()), windows.Handle(f.Fd()),
windows.FileBasicInfo, windows.FileBasicInfo,
(*byte)(unsafe.Pointer(bi)), (*byte)(unsafe.Pointer(&biAligned)),
uint32(unsafe.Sizeof(*bi)), uint32(unsafe.Sizeof(biAligned)),
); err != nil { ); err != nil {
return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err} return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err}
} }

View File

@@ -10,7 +10,6 @@ import (
"io" "io"
"net" "net"
"os" "os"
"syscall"
"time" "time"
"unsafe" "unsafe"
@@ -181,13 +180,13 @@ type HvsockConn struct {
var _ net.Conn = &HvsockConn{} var _ net.Conn = &HvsockConn{}
func newHVSocket() (*win32File, error) { func newHVSocket() (*win32File, error) {
fd, err := syscall.Socket(afHVSock, syscall.SOCK_STREAM, 1) fd, err := windows.Socket(afHVSock, windows.SOCK_STREAM, 1)
if err != nil { if err != nil {
return nil, os.NewSyscallError("socket", err) return nil, os.NewSyscallError("socket", err)
} }
f, err := makeWin32File(fd) f, err := makeWin32File(fd)
if err != nil { if err != nil {
syscall.Close(fd) windows.Close(fd)
return nil, err return nil, err
} }
f.socket = true f.socket = true
@@ -197,16 +196,24 @@ func newHVSocket() (*win32File, error) {
// ListenHvsock listens for connections on the specified hvsock address. // ListenHvsock listens for connections on the specified hvsock address.
func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) { func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
l := &HvsockListener{addr: *addr} l := &HvsockListener{addr: *addr}
sock, err := newHVSocket()
var sock *win32File
sock, err = newHVSocket()
if err != nil { if err != nil {
return nil, l.opErr("listen", err) return nil, l.opErr("listen", err)
} }
defer func() {
if err != nil {
_ = sock.Close()
}
}()
sa := addr.raw() sa := addr.raw()
err = socket.Bind(windows.Handle(sock.handle), &sa) err = socket.Bind(sock.handle, &sa)
if err != nil { if err != nil {
return nil, l.opErr("listen", os.NewSyscallError("socket", err)) return nil, l.opErr("listen", os.NewSyscallError("socket", err))
} }
err = syscall.Listen(sock.handle, 16) err = windows.Listen(sock.handle, 16)
if err != nil { if err != nil {
return nil, l.opErr("listen", os.NewSyscallError("listen", err)) return nil, l.opErr("listen", os.NewSyscallError("listen", err))
} }
@@ -246,7 +253,7 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) {
var addrbuf [addrlen * 2]byte var addrbuf [addrlen * 2]byte
var bytes uint32 var bytes uint32
err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o) err = windows.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o)
if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil { if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
} }
@@ -263,7 +270,7 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) {
conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen]))) conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
// initialize the accepted socket and update its properties with those of the listening socket // initialize the accepted socket and update its properties with those of the listening socket
if err = windows.Setsockopt(windows.Handle(sock.handle), if err = windows.Setsockopt(sock.handle,
windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT, windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil { (*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err)) return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
@@ -334,7 +341,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock
}() }()
sa := addr.raw() sa := addr.raw()
err = socket.Bind(windows.Handle(sock.handle), &sa) err = socket.Bind(sock.handle, &sa)
if err != nil { if err != nil {
return nil, conn.opErr(op, os.NewSyscallError("bind", err)) return nil, conn.opErr(op, os.NewSyscallError("bind", err))
} }
@@ -347,7 +354,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock
var bytes uint32 var bytes uint32
for i := uint(0); i <= d.Retries; i++ { for i := uint(0); i <= d.Retries; i++ {
err = socket.ConnectEx( err = socket.ConnectEx(
windows.Handle(sock.handle), sock.handle,
&sa, &sa,
nil, // sendBuf nil, // sendBuf
0, // sendDataLen 0, // sendDataLen
@@ -367,7 +374,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock
// update the connection properties, so shutdown can be used // update the connection properties, so shutdown can be used
if err = windows.Setsockopt( if err = windows.Setsockopt(
windows.Handle(sock.handle), sock.handle,
windows.SOL_SOCKET, windows.SOL_SOCKET,
windows.SO_UPDATE_CONNECT_CONTEXT, windows.SO_UPDATE_CONNECT_CONTEXT,
nil, // optvalue nil, // optvalue
@@ -378,7 +385,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock
// get the local name // get the local name
var sal rawHvsockAddr var sal rawHvsockAddr
err = socket.GetSockName(windows.Handle(sock.handle), &sal) err = socket.GetSockName(sock.handle, &sal)
if err != nil { if err != nil {
return nil, conn.opErr(op, os.NewSyscallError("getsockname", err)) return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
} }
@@ -421,7 +428,7 @@ func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
return ctx.Err() return ctx.Err()
} }
// assumes error is a plain, unwrapped syscall.Errno provided by direct syscall. // assumes error is a plain, unwrapped windows.Errno provided by direct syscall.
func canRedial(err error) bool { func canRedial(err error) bool {
//nolint:errorlint // guaranteed to be an Errno //nolint:errorlint // guaranteed to be an Errno
switch err { switch err {
@@ -447,9 +454,9 @@ func (conn *HvsockConn) Read(b []byte) (int, error) {
return 0, conn.opErr("read", err) return 0, conn.opErr("read", err)
} }
defer conn.sock.wg.Done() defer conn.sock.wg.Done()
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
var flags, bytes uint32 var flags, bytes uint32
err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) err = windows.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err) n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
if err != nil { if err != nil {
var eno windows.Errno var eno windows.Errno
@@ -482,9 +489,9 @@ func (conn *HvsockConn) write(b []byte) (int, error) {
return 0, conn.opErr("write", err) return 0, conn.opErr("write", err)
} }
defer conn.sock.wg.Done() defer conn.sock.wg.Done()
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
var bytes uint32 var bytes uint32
err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) err = windows.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err) n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
if err != nil { if err != nil {
var eno windows.Errno var eno windows.Errno
@@ -511,7 +518,7 @@ func (conn *HvsockConn) shutdown(how int) error {
return socket.ErrSocketClosed return socket.ErrSocketClosed
} }
err := syscall.Shutdown(conn.sock.handle, how) err := windows.Shutdown(conn.sock.handle, how)
if err != nil { if err != nil {
// If the connection was closed, shutdowns fail with "not connected" // If the connection was closed, shutdowns fail with "not connected"
if errors.Is(err, windows.WSAENOTCONN) || if errors.Is(err, windows.WSAENOTCONN) ||
@@ -525,7 +532,7 @@ func (conn *HvsockConn) shutdown(how int) error {
// CloseRead shuts down the read end of the socket, preventing future read operations. // CloseRead shuts down the read end of the socket, preventing future read operations.
func (conn *HvsockConn) CloseRead() error { func (conn *HvsockConn) CloseRead() error {
err := conn.shutdown(syscall.SHUT_RD) err := conn.shutdown(windows.SHUT_RD)
if err != nil { if err != nil {
return conn.opErr("closeread", err) return conn.opErr("closeread", err)
} }
@@ -535,7 +542,7 @@ func (conn *HvsockConn) CloseRead() error {
// CloseWrite shuts down the write end of the socket, preventing future write operations and // CloseWrite shuts down the write end of the socket, preventing future write operations and
// notifying the other endpoint that no more data will be written. // notifying the other endpoint that no more data will be written.
func (conn *HvsockConn) CloseWrite() error { func (conn *HvsockConn) CloseWrite() error {
err := conn.shutdown(syscall.SHUT_WR) err := conn.shutdown(windows.SHUT_WR)
if err != nil { if err != nil {
return conn.opErr("closewrite", err) return conn.opErr("closewrite", err)
} }

View File

@@ -11,12 +11,14 @@ import (
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go fs.go //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go fs.go
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew // https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew
//sys CreateFile(name string, access AccessMask, mode FileShareMode, sa *syscall.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW //sys CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
const NullHandle windows.Handle = 0 const NullHandle windows.Handle = 0
// AccessMask defines standard, specific, and generic rights. // AccessMask defines standard, specific, and generic rights.
// //
// Used with CreateFile and NtCreateFile (and co.).
//
// Bitmask: // Bitmask:
// 3 3 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 // 3 3 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1
// 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 // 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0 9 8 7 6 5 4 3 2 1 0
@@ -47,6 +49,12 @@ const (
// https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew#parameters // https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew#parameters
FILE_ANY_ACCESS AccessMask = 0 FILE_ANY_ACCESS AccessMask = 0
GENERIC_READ AccessMask = 0x8000_0000
GENERIC_WRITE AccessMask = 0x4000_0000
GENERIC_EXECUTE AccessMask = 0x2000_0000
GENERIC_ALL AccessMask = 0x1000_0000
ACCESS_SYSTEM_SECURITY AccessMask = 0x0100_0000
// Specific Object Access // Specific Object Access
// from ntioapi.h // from ntioapi.h
@@ -124,14 +132,32 @@ const (
TRUNCATE_EXISTING FileCreationDisposition = 0x05 TRUNCATE_EXISTING FileCreationDisposition = 0x05
) )
// Create disposition values for NtCreate*
type NTFileCreationDisposition uint32
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
// From ntioapi.h
FILE_SUPERSEDE NTFileCreationDisposition = 0x00
FILE_OPEN NTFileCreationDisposition = 0x01
FILE_CREATE NTFileCreationDisposition = 0x02
FILE_OPEN_IF NTFileCreationDisposition = 0x03
FILE_OVERWRITE NTFileCreationDisposition = 0x04
FILE_OVERWRITE_IF NTFileCreationDisposition = 0x05
FILE_MAXIMUM_DISPOSITION NTFileCreationDisposition = 0x05
)
// CreateFile and co. take flags or attributes together as one parameter. // CreateFile and co. take flags or attributes together as one parameter.
// Define alias until we can use generics to allow both // Define alias until we can use generics to allow both
//
// https://learn.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants // https://learn.microsoft.com/en-us/windows/win32/fileio/file-attribute-constants
type FileFlagOrAttribute uint32 type FileFlagOrAttribute uint32
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. //nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const ( // from winnt.h const (
// from winnt.h
FILE_FLAG_WRITE_THROUGH FileFlagOrAttribute = 0x8000_0000 FILE_FLAG_WRITE_THROUGH FileFlagOrAttribute = 0x8000_0000
FILE_FLAG_OVERLAPPED FileFlagOrAttribute = 0x4000_0000 FILE_FLAG_OVERLAPPED FileFlagOrAttribute = 0x4000_0000
FILE_FLAG_NO_BUFFERING FileFlagOrAttribute = 0x2000_0000 FILE_FLAG_NO_BUFFERING FileFlagOrAttribute = 0x2000_0000
@@ -145,17 +171,51 @@ const ( // from winnt.h
FILE_FLAG_FIRST_PIPE_INSTANCE FileFlagOrAttribute = 0x0008_0000 FILE_FLAG_FIRST_PIPE_INSTANCE FileFlagOrAttribute = 0x0008_0000
) )
// NtCreate* functions take a dedicated CreateOptions parameter.
//
// https://learn.microsoft.com/en-us/windows/win32/api/Winternl/nf-winternl-ntcreatefile
//
// https://learn.microsoft.com/en-us/windows/win32/devnotes/nt-create-named-pipe-file
type NTCreateOptions uint32
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const (
// From ntioapi.h
FILE_DIRECTORY_FILE NTCreateOptions = 0x0000_0001
FILE_WRITE_THROUGH NTCreateOptions = 0x0000_0002
FILE_SEQUENTIAL_ONLY NTCreateOptions = 0x0000_0004
FILE_NO_INTERMEDIATE_BUFFERING NTCreateOptions = 0x0000_0008
FILE_SYNCHRONOUS_IO_ALERT NTCreateOptions = 0x0000_0010
FILE_SYNCHRONOUS_IO_NONALERT NTCreateOptions = 0x0000_0020
FILE_NON_DIRECTORY_FILE NTCreateOptions = 0x0000_0040
FILE_CREATE_TREE_CONNECTION NTCreateOptions = 0x0000_0080
FILE_COMPLETE_IF_OPLOCKED NTCreateOptions = 0x0000_0100
FILE_NO_EA_KNOWLEDGE NTCreateOptions = 0x0000_0200
FILE_DISABLE_TUNNELING NTCreateOptions = 0x0000_0400
FILE_RANDOM_ACCESS NTCreateOptions = 0x0000_0800
FILE_DELETE_ON_CLOSE NTCreateOptions = 0x0000_1000
FILE_OPEN_BY_FILE_ID NTCreateOptions = 0x0000_2000
FILE_OPEN_FOR_BACKUP_INTENT NTCreateOptions = 0x0000_4000
FILE_NO_COMPRESSION NTCreateOptions = 0x0000_8000
)
type FileSQSFlag = FileFlagOrAttribute type FileSQSFlag = FileFlagOrAttribute
//nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API. //nolint:revive // SNAKE_CASE is not idiomatic in Go, but aligned with Win32 API.
const ( // from winbase.h const (
// from winbase.h
SECURITY_ANONYMOUS FileSQSFlag = FileSQSFlag(SecurityAnonymous << 16) SECURITY_ANONYMOUS FileSQSFlag = FileSQSFlag(SecurityAnonymous << 16)
SECURITY_IDENTIFICATION FileSQSFlag = FileSQSFlag(SecurityIdentification << 16) SECURITY_IDENTIFICATION FileSQSFlag = FileSQSFlag(SecurityIdentification << 16)
SECURITY_IMPERSONATION FileSQSFlag = FileSQSFlag(SecurityImpersonation << 16) SECURITY_IMPERSONATION FileSQSFlag = FileSQSFlag(SecurityImpersonation << 16)
SECURITY_DELEGATION FileSQSFlag = FileSQSFlag(SecurityDelegation << 16) SECURITY_DELEGATION FileSQSFlag = FileSQSFlag(SecurityDelegation << 16)
SECURITY_SQOS_PRESENT FileSQSFlag = 0x00100000 SECURITY_SQOS_PRESENT FileSQSFlag = 0x0010_0000
SECURITY_VALID_SQOS_FLAGS FileSQSFlag = 0x001F0000 SECURITY_VALID_SQOS_FLAGS FileSQSFlag = 0x001F_0000
) )
// GetFinalPathNameByHandle flags // GetFinalPathNameByHandle flags

View File

@@ -33,9 +33,6 @@ func errnoErr(e syscall.Errno) error {
case errnoERROR_IO_PENDING: case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING return errERROR_IO_PENDING
} }
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e return e
} }
@@ -45,7 +42,7 @@ var (
procCreateFileW = modkernel32.NewProc("CreateFileW") procCreateFileW = modkernel32.NewProc("CreateFileW")
) )
func CreateFile(name string, access AccessMask, mode FileShareMode, sa *syscall.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) { func CreateFile(name string, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) {
var _p0 *uint16 var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name) _p0, err = syscall.UTF16PtrFromString(name)
if err != nil { if err != nil {
@@ -54,8 +51,8 @@ func CreateFile(name string, access AccessMask, mode FileShareMode, sa *syscall.
return _CreateFile(_p0, access, mode, sa, createmode, attrs, templatefile) return _CreateFile(_p0, access, mode, sa, createmode, attrs, templatefile)
} }
func _CreateFile(name *uint16, access AccessMask, mode FileShareMode, sa *syscall.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) { func _CreateFile(name *uint16, access AccessMask, mode FileShareMode, sa *windows.SecurityAttributes, createmode FileCreationDisposition, attrs FileFlagOrAttribute, templatefile windows.Handle) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0) r0, _, e1 := syscall.SyscallN(procCreateFileW.Addr(), uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile))
handle = windows.Handle(r0) handle = windows.Handle(r0)
if handle == windows.InvalidHandle { if handle == windows.InvalidHandle {
err = errnoErr(e1) err = errnoErr(e1)

View File

@@ -156,9 +156,7 @@ func connectEx(
bytesSent *uint32, bytesSent *uint32,
overlapped *windows.Overlapped, overlapped *windows.Overlapped,
) (err error) { ) (err error) {
// todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN r1, _, e1 := syscall.SyscallN(connectExFunc.addr,
r1, _, e1 := syscall.Syscall9(connectExFunc.addr,
7,
uintptr(s), uintptr(s),
uintptr(name), uintptr(name),
uintptr(namelen), uintptr(namelen),
@@ -166,8 +164,8 @@ func connectEx(
uintptr(sendDataLen), uintptr(sendDataLen),
uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(bytesSent)),
uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(overlapped)),
0, )
0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
err = error(e1) err = error(e1)

View File

@@ -33,9 +33,6 @@ func errnoErr(e syscall.Errno) error {
case errnoERROR_IO_PENDING: case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING return errERROR_IO_PENDING
} }
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e return e
} }
@@ -48,7 +45,7 @@ var (
) )
func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) { func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) {
r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen)) r1, _, e1 := syscall.SyscallN(procbind.Addr(), uintptr(s), uintptr(name), uintptr(namelen))
if r1 == socketError { if r1 == socketError {
err = errnoErr(e1) err = errnoErr(e1)
} }
@@ -56,7 +53,7 @@ func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) {
} }
func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) { func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
r1, _, e1 := syscall.Syscall(procgetpeername.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen))) r1, _, e1 := syscall.SyscallN(procgetpeername.Addr(), uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
if r1 == socketError { if r1 == socketError {
err = errnoErr(e1) err = errnoErr(e1)
} }
@@ -64,7 +61,7 @@ func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err err
} }
func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) { func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
r1, _, e1 := syscall.Syscall(procgetsockname.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen))) r1, _, e1 := syscall.SyscallN(procgetsockname.Addr(), uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
if r1 == socketError { if r1 == socketError {
err = errnoErr(e1) err = errnoErr(e1)
} }

View File

@@ -62,7 +62,7 @@ func (b *WString) Free() {
// ResizeTo grows the buffer to at least c and returns the new capacity, freeing the // ResizeTo grows the buffer to at least c and returns the new capacity, freeing the
// previous buffer back into pool. // previous buffer back into pool.
func (b *WString) ResizeTo(c uint32) uint32 { func (b *WString) ResizeTo(c uint32) uint32 {
// allready sufficient (or n is 0) // already sufficient (or n is 0)
if c <= b.Cap() { if c <= b.Cap() {
return b.Cap() return b.Cap()
} }

View File

@@ -11,7 +11,6 @@ import (
"net" "net"
"os" "os"
"runtime" "runtime"
"syscall"
"time" "time"
"unsafe" "unsafe"
@@ -20,20 +19,44 @@ import (
"github.com/Microsoft/go-winio/internal/fs" "github.com/Microsoft/go-winio/internal/fs"
) )
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe //sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW //sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo //sys disconnectNamedPipe(pipe windows.Handle) (err error) = DisconnectNamedPipe
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW //sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc //sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) = ntdll.NtCreateNamedPipeFile //sys ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntStatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb //sys rtlNtStatusToDosError(status ntStatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U //sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl //sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl
type PipeConn interface {
net.Conn
Disconnect() error
Flush() error
}
// type aliases for mkwinsyscall code
type (
ntAccessMask = fs.AccessMask
ntFileShareMode = fs.FileShareMode
ntFileCreationDisposition = fs.NTFileCreationDisposition
ntFileOptions = fs.NTCreateOptions
)
type ioStatusBlock struct { type ioStatusBlock struct {
Status, Information uintptr Status, Information uintptr
} }
// typedef struct _OBJECT_ATTRIBUTES {
// ULONG Length;
// HANDLE RootDirectory;
// PUNICODE_STRING ObjectName;
// ULONG Attributes;
// PVOID SecurityDescriptor;
// PVOID SecurityQualityOfService;
// } OBJECT_ATTRIBUTES;
//
// https://learn.microsoft.com/en-us/windows/win32/api/ntdef/ns-ntdef-_object_attributes
type objectAttributes struct { type objectAttributes struct {
Length uintptr Length uintptr
RootDirectory uintptr RootDirectory uintptr
@@ -49,6 +72,17 @@ type unicodeString struct {
Buffer uintptr Buffer uintptr
} }
// typedef struct _SECURITY_DESCRIPTOR {
// BYTE Revision;
// BYTE Sbz1;
// SECURITY_DESCRIPTOR_CONTROL Control;
// PSID Owner;
// PSID Group;
// PACL Sacl;
// PACL Dacl;
// } SECURITY_DESCRIPTOR, *PISECURITY_DESCRIPTOR;
//
// https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-security_descriptor
type securityDescriptor struct { type securityDescriptor struct {
Revision byte Revision byte
Sbz1 byte Sbz1 byte
@@ -80,6 +114,8 @@ type win32Pipe struct {
path string path string
} }
var _ PipeConn = (*win32Pipe)(nil)
type win32MessageBytePipe struct { type win32MessageBytePipe struct {
win32Pipe win32Pipe
writeClosed bool writeClosed bool
@@ -103,6 +139,10 @@ func (f *win32Pipe) SetDeadline(t time.Time) error {
return f.SetWriteDeadline(t) return f.SetWriteDeadline(t)
} }
func (f *win32Pipe) Disconnect() error {
return disconnectNamedPipe(f.win32File.handle)
}
// CloseWrite closes the write side of a message pipe in byte mode. // CloseWrite closes the write side of a message pipe in byte mode.
func (f *win32MessageBytePipe) CloseWrite() error { func (f *win32MessageBytePipe) CloseWrite() error {
if f.writeClosed { if f.writeClosed {
@@ -146,7 +186,7 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
// zero-byte message, ensure that all future Read() calls // zero-byte message, ensure that all future Read() calls
// also return EOF. // also return EOF.
f.readEOF = true f.readEOF = true
} else if err == syscall.ERROR_MORE_DATA { //nolint:errorlint // err is Errno } else if err == windows.ERROR_MORE_DATA { //nolint:errorlint // err is Errno
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode // ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since // and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams. // this package presents all named pipes as byte streams.
@@ -164,21 +204,20 @@ func (s pipeAddress) String() string {
} }
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. // tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string, access fs.AccessMask) (syscall.Handle, error) { func tryDialPipe(ctx context.Context, path *string, access fs.AccessMask, impLevel PipeImpLevel) (windows.Handle, error) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return syscall.Handle(0), ctx.Err() return windows.Handle(0), ctx.Err()
default: default:
wh, err := fs.CreateFile(*path, h, err := fs.CreateFile(*path,
access, access,
0, // mode 0, // mode
nil, // security attributes nil, // security attributes
fs.OPEN_EXISTING, fs.OPEN_EXISTING,
fs.FILE_FLAG_OVERLAPPED|fs.SECURITY_SQOS_PRESENT|fs.SECURITY_ANONYMOUS, fs.FILE_FLAG_OVERLAPPED|fs.SECURITY_SQOS_PRESENT|fs.FileSQSFlag(impLevel),
0, // template file handle 0, // template file handle
) )
h := syscall.Handle(wh)
if err == nil { if err == nil {
return h, nil return h, nil
} }
@@ -214,15 +253,33 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx` // DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout. // cancellation or timeout.
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) { func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
return DialPipeAccess(ctx, path, syscall.GENERIC_READ|syscall.GENERIC_WRITE) return DialPipeAccess(ctx, path, uint32(fs.GENERIC_READ|fs.GENERIC_WRITE))
} }
// PipeImpLevel is an enumeration of impersonation levels that may be set
// when calling DialPipeAccessImpersonation.
type PipeImpLevel uint32
const (
PipeImpLevelAnonymous = PipeImpLevel(fs.SECURITY_ANONYMOUS)
PipeImpLevelIdentification = PipeImpLevel(fs.SECURITY_IDENTIFICATION)
PipeImpLevelImpersonation = PipeImpLevel(fs.SECURITY_IMPERSONATION)
PipeImpLevelDelegation = PipeImpLevel(fs.SECURITY_DELEGATION)
)
// DialPipeAccess attempts to connect to a named pipe by `path` with `access` until `ctx` // DialPipeAccess attempts to connect to a named pipe by `path` with `access` until `ctx`
// cancellation or timeout. // cancellation or timeout.
func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn, error) { func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn, error) {
return DialPipeAccessImpLevel(ctx, path, access, PipeImpLevelAnonymous)
}
// DialPipeAccessImpLevel attempts to connect to a named pipe by `path` with
// `access` at `impLevel` until `ctx` cancellation or timeout. The other
// DialPipe* implementations use PipeImpLevelAnonymous.
func DialPipeAccessImpLevel(ctx context.Context, path string, access uint32, impLevel PipeImpLevel) (net.Conn, error) {
var err error var err error
var h syscall.Handle var h windows.Handle
h, err = tryDialPipe(ctx, &path, fs.AccessMask(access)) h, err = tryDialPipe(ctx, &path, fs.AccessMask(access), impLevel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -235,7 +292,7 @@ func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn,
f, err := makeWin32File(h) f, err := makeWin32File(h)
if err != nil { if err != nil {
syscall.Close(h) windows.Close(h)
return nil, err return nil, err
} }
@@ -255,7 +312,7 @@ type acceptResponse struct {
} }
type win32PipeListener struct { type win32PipeListener struct {
firstHandle syscall.Handle firstHandle windows.Handle
path string path string
config PipeConfig config PipeConfig
acceptCh chan (chan acceptResponse) acceptCh chan (chan acceptResponse)
@@ -263,8 +320,8 @@ type win32PipeListener struct {
doneCh chan int doneCh chan int
} }
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) { func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (windows.Handle, error) {
path16, err := syscall.UTF16FromString(path) path16, err := windows.UTF16FromString(path)
if err != nil { if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
@@ -280,16 +337,20 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
).Err(); err != nil { ).Err(); err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
defer localFree(ntPath.Buffer) defer windows.LocalFree(windows.Handle(ntPath.Buffer)) //nolint:errcheck
oa.ObjectName = &ntPath oa.ObjectName = &ntPath
oa.Attributes = windows.OBJ_CASE_INSENSITIVE oa.Attributes = windows.OBJ_CASE_INSENSITIVE
// The security descriptor is only needed for the first pipe. // The security descriptor is only needed for the first pipe.
if first { if first {
if sd != nil { if sd != nil {
//todo: does `sdb` need to be allocated on the heap, or can go allocate it?
l := uint32(len(sd)) l := uint32(len(sd))
sdb := localAlloc(0, l) sdb, err := windows.LocalAlloc(0, l)
defer localFree(sdb) if err != nil {
return 0, fmt.Errorf("LocalAlloc for security descriptor with of length %d: %w", l, err)
}
defer windows.LocalFree(windows.Handle(sdb)) //nolint:errcheck
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd) copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb)) oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
} else { } else {
@@ -298,7 +359,7 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
return 0, fmt.Errorf("getting default named pipe ACL: %w", err) return 0, fmt.Errorf("getting default named pipe ACL: %w", err)
} }
defer localFree(dacl) defer windows.LocalFree(windows.Handle(dacl)) //nolint:errcheck
sdb := &securityDescriptor{ sdb := &securityDescriptor{
Revision: 1, Revision: 1,
@@ -314,27 +375,27 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
typ |= windows.FILE_PIPE_MESSAGE_TYPE typ |= windows.FILE_PIPE_MESSAGE_TYPE
} }
disposition := uint32(windows.FILE_OPEN) disposition := fs.FILE_OPEN
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE) access := fs.GENERIC_READ | fs.GENERIC_WRITE | fs.SYNCHRONIZE
if first { if first {
disposition = windows.FILE_CREATE disposition = fs.FILE_CREATE
// By not asking for read or write access, the named pipe file system // By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking // will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false. // client connections until the next call with first == false.
access = syscall.SYNCHRONIZE access = fs.SYNCHRONIZE
} }
timeout := int64(-50 * 10000) // 50ms timeout := int64(-50 * 10000) // 50ms
var ( var (
h syscall.Handle h windows.Handle
iosb ioStatusBlock iosb ioStatusBlock
) )
err = ntCreateNamedPipeFile(&h, err = ntCreateNamedPipeFile(&h,
access, access,
&oa, &oa,
&iosb, &iosb,
syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, fs.FILE_SHARE_READ|fs.FILE_SHARE_WRITE,
disposition, disposition,
0, 0,
typ, typ,
@@ -359,7 +420,7 @@ func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
} }
f, err := makeWin32File(h) f, err := makeWin32File(h)
if err != nil { if err != nil {
syscall.Close(h) windows.Close(h)
return nil, err return nil, err
} }
return f, nil return f, nil
@@ -418,7 +479,7 @@ func (l *win32PipeListener) listenerRoutine() {
closed = err == ErrPipeListenerClosed //nolint:errorlint // err is Errno closed = err == ErrPipeListenerClosed //nolint:errorlint // err is Errno
} }
} }
syscall.Close(l.firstHandle) windows.Close(l.firstHandle)
l.firstHandle = 0 l.firstHandle = 0
// Notify Close() and Accept() callers that the handle has been closed. // Notify Close() and Accept() callers that the handle has been closed.
close(l.doneCh) close(l.doneCh)

View File

@@ -9,7 +9,6 @@ import (
"fmt" "fmt"
"runtime" "runtime"
"sync" "sync"
"syscall"
"unicode/utf16" "unicode/utf16"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@@ -18,8 +17,8 @@ import (
//sys adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) [true] = advapi32.AdjustTokenPrivileges //sys adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) [true] = advapi32.AdjustTokenPrivileges
//sys impersonateSelf(level uint32) (err error) = advapi32.ImpersonateSelf //sys impersonateSelf(level uint32) (err error) = advapi32.ImpersonateSelf
//sys revertToSelf() (err error) = advapi32.RevertToSelf //sys revertToSelf() (err error) = advapi32.RevertToSelf
//sys openThreadToken(thread syscall.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) = advapi32.OpenThreadToken //sys openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) = advapi32.OpenThreadToken
//sys getCurrentThread() (h syscall.Handle) = GetCurrentThread //sys getCurrentThread() (h windows.Handle) = GetCurrentThread
//sys lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) = advapi32.LookupPrivilegeValueW //sys lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) = advapi32.LookupPrivilegeValueW
//sys lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) = advapi32.LookupPrivilegeNameW //sys lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) = advapi32.LookupPrivilegeNameW
//sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW //sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW
@@ -29,7 +28,7 @@ const (
SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED
//revive:disable-next-line:var-naming ALL_CAPS //revive:disable-next-line:var-naming ALL_CAPS
ERROR_NOT_ALL_ASSIGNED syscall.Errno = windows.ERROR_NOT_ALL_ASSIGNED ERROR_NOT_ALL_ASSIGNED windows.Errno = windows.ERROR_NOT_ALL_ASSIGNED
SeBackupPrivilege = "SeBackupPrivilege" SeBackupPrivilege = "SeBackupPrivilege"
SeRestorePrivilege = "SeRestorePrivilege" SeRestorePrivilege = "SeRestorePrivilege"
@@ -177,7 +176,7 @@ func newThreadToken() (windows.Token, error) {
} }
var token windows.Token var token windows.Token
err = openThreadToken(getCurrentThread(), syscall.TOKEN_ADJUST_PRIVILEGES|syscall.TOKEN_QUERY, false, &token) err = openThreadToken(getCurrentThread(), windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, false, &token)
if err != nil { if err != nil {
rerr := revertToSelf() rerr := revertToSelf()
if rerr != nil { if rerr != nil {

View File

@@ -5,7 +5,7 @@ package winio
import ( import (
"errors" "errors"
"syscall" "fmt"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@@ -15,10 +15,6 @@ import (
//sys lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountSidW //sys lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountSidW
//sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW //sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW
//sys convertStringSidToSid(str *uint16, sid **byte) (err error) = advapi32.ConvertStringSidToSidW //sys convertStringSidToSid(str *uint16, sid **byte) (err error) = advapi32.ConvertStringSidToSidW
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
//sys convertSecurityDescriptorToStringSecurityDescriptor(sd *byte, revision uint32, secInfo uint32, sddl **uint16, sddlSize *uint32) (err error) = advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW
//sys localFree(mem uintptr) = LocalFree
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
type AccountLookupError struct { type AccountLookupError struct {
Name string Name string
@@ -64,7 +60,7 @@ func LookupSidByName(name string) (sid string, err error) {
var sidSize, sidNameUse, refDomainSize uint32 var sidSize, sidNameUse, refDomainSize uint32
err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse) err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse)
if err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
return "", &AccountLookupError{name, err} return "", &AccountLookupError{name, err}
} }
sidBuffer := make([]byte, sidSize) sidBuffer := make([]byte, sidSize)
@@ -78,8 +74,8 @@ func LookupSidByName(name string) (sid string, err error) {
if err != nil { if err != nil {
return "", &AccountLookupError{name, err} return "", &AccountLookupError{name, err}
} }
sid = syscall.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(strBuffer))[:]) sid = windows.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(strBuffer))[:])
localFree(uintptr(unsafe.Pointer(strBuffer))) _, _ = windows.LocalFree(windows.Handle(unsafe.Pointer(strBuffer)))
return sid, nil return sid, nil
} }
@@ -100,7 +96,7 @@ func LookupNameBySid(sid string) (name string, err error) {
if err = convertStringSidToSid(sidBuffer, &sidPtr); err != nil { if err = convertStringSidToSid(sidBuffer, &sidPtr); err != nil {
return "", &AccountLookupError{sid, err} return "", &AccountLookupError{sid, err}
} }
defer localFree(uintptr(unsafe.Pointer(sidPtr))) defer windows.LocalFree(windows.Handle(unsafe.Pointer(sidPtr))) //nolint:errcheck
var nameSize, refDomainSize, sidNameUse uint32 var nameSize, refDomainSize, sidNameUse uint32
err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse) err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse)
@@ -120,25 +116,18 @@ func LookupNameBySid(sid string) (name string, err error) {
} }
func SddlToSecurityDescriptor(sddl string) ([]byte, error) { func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
var sdBuffer uintptr sd, err := windows.SecurityDescriptorFromString(sddl)
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
if err != nil { if err != nil {
return nil, &SddlConversionError{sddl, err} return nil, &SddlConversionError{Sddl: sddl, Err: err}
} }
defer localFree(sdBuffer) b := unsafe.Slice((*byte)(unsafe.Pointer(sd)), sd.Length())
sd := make([]byte, getSecurityDescriptorLength(sdBuffer)) return b, nil
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
return sd, nil
} }
func SecurityDescriptorToSddl(sd []byte) (string, error) { func SecurityDescriptorToSddl(sd []byte) (string, error) {
var sddl *uint16 if l := int(unsafe.Sizeof(windows.SECURITY_DESCRIPTOR{})); len(sd) < l {
// The returned string length seems to include an arbitrary number of terminating NULs. return "", fmt.Errorf("SecurityDescriptor (%d) smaller than expected (%d): %w", len(sd), l, windows.ERROR_INCORRECT_SIZE)
// Don't use it.
err := convertSecurityDescriptorToStringSecurityDescriptor(&sd[0], 1, 0xff, &sddl, nil)
if err != nil {
return "", err
} }
defer localFree(uintptr(unsafe.Pointer(sddl))) s := (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(&sd[0]))
return syscall.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(sddl))[:]), nil return s.String(), nil
} }

View File

@@ -1,5 +0,0 @@
//go:build tools
package winio
import _ "golang.org/x/tools/cmd/stringer"

View File

@@ -33,9 +33,6 @@ func errnoErr(e syscall.Errno) error {
case errnoERROR_IO_PENDING: case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING return errERROR_IO_PENDING
} }
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e return e
} }
@@ -45,38 +42,34 @@ var (
modntdll = windows.NewLazySystemDLL("ntdll.dll") modntdll = windows.NewLazySystemDLL("ntdll.dll")
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges") procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW") procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW")
procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW") procConvertStringSidToSidW = modadvapi32.NewProc("ConvertStringSidToSidW")
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW") procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf")
procConvertStringSidToSidW = modadvapi32.NewProc("ConvertStringSidToSidW") procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength") procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW")
procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf") procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW")
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW") procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW")
procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW") procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW")
procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW") procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken")
procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW") procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW") procBackupRead = modkernel32.NewProc("BackupRead")
procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken") procBackupWrite = modkernel32.NewProc("BackupWrite")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf") procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procBackupRead = modkernel32.NewProc("BackupRead") procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
procBackupWrite = modkernel32.NewProc("BackupWrite") procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procCancelIoEx = modkernel32.NewProc("CancelIoEx") procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") procDisconnectNamedPipe = modkernel32.NewProc("DisconnectNamedPipe")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort") procGetCurrentThread = modkernel32.NewProc("GetCurrentThread")
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
procGetCurrentThread = modkernel32.NewProc("GetCurrentThread") procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW") procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo") procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
procLocalAlloc = modkernel32.NewProc("LocalAlloc") procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
procLocalFree = modkernel32.NewProc("LocalFree") procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes") procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile") procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
) )
func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) { func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) {
@@ -84,7 +77,7 @@ func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, ou
if releaseAll { if releaseAll {
_p0 = 1 _p0 = 1
} }
r0, _, e1 := syscall.Syscall6(procAdjustTokenPrivileges.Addr(), 6, uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize))) r0, _, e1 := syscall.SyscallN(procAdjustTokenPrivileges.Addr(), uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize)))
success = r0 != 0 success = r0 != 0
if true { if true {
err = errnoErr(e1) err = errnoErr(e1)
@@ -92,33 +85,8 @@ func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, ou
return return
} }
func convertSecurityDescriptorToStringSecurityDescriptor(sd *byte, revision uint32, secInfo uint32, sddl **uint16, sddlSize *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertSecurityDescriptorToStringSecurityDescriptorW.Addr(), 5, uintptr(unsafe.Pointer(sd)), uintptr(revision), uintptr(secInfo), uintptr(unsafe.Pointer(sddl)), uintptr(unsafe.Pointer(sddlSize)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func convertSidToStringSid(sid *byte, str **uint16) (err error) { func convertSidToStringSid(sid *byte, str **uint16) (err error) {
r1, _, e1 := syscall.Syscall(procConvertSidToStringSidW.Addr(), 2, uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(str)), 0) r1, _, e1 := syscall.SyscallN(procConvertSidToStringSidW.Addr(), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(str)))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(str)
if err != nil {
return
}
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
}
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
@@ -126,21 +94,15 @@ func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision
} }
func convertStringSidToSid(str *uint16, sid **byte) (err error) { func convertStringSidToSid(str *uint16, sid **byte) (err error) {
r1, _, e1 := syscall.Syscall(procConvertStringSidToSidW.Addr(), 2, uintptr(unsafe.Pointer(str)), uintptr(unsafe.Pointer(sid)), 0) r1, _, e1 := syscall.SyscallN(procConvertStringSidToSidW.Addr(), uintptr(unsafe.Pointer(str)), uintptr(unsafe.Pointer(sid)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
len = uint32(r0)
return
}
func impersonateSelf(level uint32) (err error) { func impersonateSelf(level uint32) (err error) {
r1, _, e1 := syscall.Syscall(procImpersonateSelf.Addr(), 1, uintptr(level), 0, 0) r1, _, e1 := syscall.SyscallN(procImpersonateSelf.Addr(), uintptr(level))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
@@ -157,7 +119,7 @@ func lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSiz
} }
func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) { func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procLookupAccountNameW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0) r1, _, e1 := syscall.SyscallN(procLookupAccountNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
@@ -165,7 +127,7 @@ func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidS
} }
func lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) { func lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procLookupAccountSidW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0) r1, _, e1 := syscall.SyscallN(procLookupAccountSidW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
@@ -182,7 +144,7 @@ func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16,
} }
func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) { func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeDisplayNameW.Addr(), 5, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageId)), 0) r1, _, e1 := syscall.SyscallN(procLookupPrivilegeDisplayNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageId)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
@@ -199,7 +161,7 @@ func lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *
} }
func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) { func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeNameW.Addr(), 4, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), 0, 0) r1, _, e1 := syscall.SyscallN(procLookupPrivilegeNameW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
@@ -221,19 +183,19 @@ func lookupPrivilegeValue(systemName string, name string, luid *uint64) (err err
} }
func _lookupPrivilegeValue(systemName *uint16, name *uint16, luid *uint64) (err error) { func _lookupPrivilegeValue(systemName *uint16, name *uint16, luid *uint64) (err error) {
r1, _, e1 := syscall.Syscall(procLookupPrivilegeValueW.Addr(), 3, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(luid))) r1, _, e1 := syscall.SyscallN(procLookupPrivilegeValueW.Addr(), uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(luid)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func openThreadToken(thread syscall.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) { func openThreadToken(thread windows.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) {
var _p0 uint32 var _p0 uint32
if openAsSelf { if openAsSelf {
_p0 = 1 _p0 = 1
} }
r1, _, e1 := syscall.Syscall6(procOpenThreadToken.Addr(), 4, uintptr(thread), uintptr(accessMask), uintptr(_p0), uintptr(unsafe.Pointer(token)), 0, 0) r1, _, e1 := syscall.SyscallN(procOpenThreadToken.Addr(), uintptr(thread), uintptr(accessMask), uintptr(_p0), uintptr(unsafe.Pointer(token)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
@@ -241,14 +203,14 @@ func openThreadToken(thread syscall.Handle, accessMask uint32, openAsSelf bool,
} }
func revertToSelf() (err error) { func revertToSelf() (err error) {
r1, _, e1 := syscall.Syscall(procRevertToSelf.Addr(), 0, 0, 0, 0) r1, _, e1 := syscall.SyscallN(procRevertToSelf.Addr())
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) { func backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
var _p0 *byte var _p0 *byte
if len(b) > 0 { if len(b) > 0 {
_p0 = &b[0] _p0 = &b[0]
@@ -261,14 +223,14 @@ func backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, proce
if processSecurity { if processSecurity {
_p2 = 1 _p2 = 1
} }
r1, _, e1 := syscall.Syscall9(procBackupRead.Addr(), 7, uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesRead)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)), 0, 0) r1, _, e1 := syscall.SyscallN(procBackupRead.Addr(), uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesRead)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func backupWrite(h syscall.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) { func backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
var _p0 *byte var _p0 *byte
if len(b) > 0 { if len(b) > 0 {
_p0 = &b[0] _p0 = &b[0]
@@ -281,39 +243,39 @@ func backupWrite(h syscall.Handle, b []byte, bytesWritten *uint32, abort bool, p
if processSecurity { if processSecurity {
_p2 = 1 _p2 = 1
} }
r1, _, e1 := syscall.Syscall9(procBackupWrite.Addr(), 7, uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesWritten)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)), 0, 0) r1, _, e1 := syscall.SyscallN(procBackupWrite.Addr(), uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesWritten)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) { func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0) r1, _, e1 := syscall.SyscallN(procCancelIoEx.Addr(), uintptr(file), uintptr(unsafe.Pointer(o)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) { func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0) r1, _, e1 := syscall.SyscallN(procConnectNamedPipe.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(o)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) { func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0) r0, _, e1 := syscall.SyscallN(procCreateIoCompletionPort.Addr(), uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount))
newport = syscall.Handle(r0) newport = windows.Handle(r0)
if newport == 0 { if newport == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) { func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
var _p0 *uint16 var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name) _p0, err = syscall.UTF16PtrFromString(name)
if err != nil { if err != nil {
@@ -322,96 +284,93 @@ func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances ui
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa) return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
} }
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) { func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0) r0, _, e1 := syscall.SyscallN(procCreateNamedPipeW.Addr(), uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)))
handle = syscall.Handle(r0) handle = windows.Handle(r0)
if handle == syscall.InvalidHandle { if handle == windows.InvalidHandle {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func getCurrentThread() (h syscall.Handle) { func disconnectNamedPipe(pipe windows.Handle) (err error) {
r0, _, _ := syscall.Syscall(procGetCurrentThread.Addr(), 0, 0, 0, 0) r1, _, e1 := syscall.SyscallN(procDisconnectNamedPipe.Addr(), uintptr(pipe))
h = syscall.Handle(r0)
return
}
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) { func getCurrentThread() (h windows.Handle) {
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0) r0, _, _ := syscall.SyscallN(procGetCurrentThread.Addr())
h = windows.Handle(r0)
return
}
func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.SyscallN(procGetNamedPipeHandleStateW.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) { func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0) r1, _, e1 := syscall.SyscallN(procGetNamedPipeInfo.Addr(), uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func localAlloc(uFlags uint32, length uint32) (ptr uintptr) { func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0) r1, _, e1 := syscall.SyscallN(procGetQueuedCompletionStatus.Addr(), uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout))
ptr = uintptr(r0)
return
}
func localFree(mem uintptr) {
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
return
}
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }
return return
} }
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) { func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0) r1, _, e1 := syscall.SyscallN(procSetFileCompletionNotificationModes.Addr(), uintptr(h), uintptr(flags))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func ntCreateNamedPipeFile(pipe *windows.Handle, access ntAccessMask, oa *objectAttributes, iosb *ioStatusBlock, share ntFileShareMode, disposition ntFileCreationDisposition, options ntFileOptions, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) {
r0, _, _ := syscall.SyscallN(procNtCreateNamedPipeFile.Addr(), uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)))
status = ntStatus(r0) status = ntStatus(r0)
return return
} }
func rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) { func rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) {
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0) r0, _, _ := syscall.SyscallN(procRtlDefaultNpAcl.Addr(), uintptr(unsafe.Pointer(dacl)))
status = ntStatus(r0) status = ntStatus(r0)
return return
} }
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) { func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) {
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0) r0, _, _ := syscall.SyscallN(procRtlDosPathNameToNtPathName_U.Addr(), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved))
status = ntStatus(r0) status = ntStatus(r0)
return return
} }
func rtlNtStatusToDosError(status ntStatus) (winerr error) { func rtlNtStatusToDosError(status ntStatus) (winerr error) {
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0) r0, _, _ := syscall.SyscallN(procRtlNtStatusToDosErrorNoTeb.Addr(), uintptr(status))
if r0 != 0 { if r0 != 0 {
winerr = syscall.Errno(r0) winerr = syscall.Errno(r0)
} }
return return
} }
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) { func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
var _p0 uint32 var _p0 uint32
if wait { if wait {
_p0 = 1 _p0 = 1
} }
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0) r1, _, e1 := syscall.SyscallN(procWSAGetOverlappedResult.Addr(), uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)))
if r1 == 0 { if r1 == 0 {
err = errnoErr(e1) err = errnoErr(e1)
} }

View File

@@ -49,16 +49,16 @@ func ShiftNBytesLeft(dst, x []byte, n int) {
dst = append(dst, make([]byte, n/8)...) dst = append(dst, make([]byte, n/8)...)
} }
// XorBytesMut assumes equal input length, replaces X with X XOR Y // XorBytesMut replaces X with X XOR Y. len(X) must be >= len(Y).
func XorBytesMut(X, Y []byte) { func XorBytesMut(X, Y []byte) {
for i := 0; i < len(X); i++ { for i := 0; i < len(Y); i++ {
X[i] ^= Y[i] X[i] ^= Y[i]
} }
} }
// XorBytes assumes equal input length, puts X XOR Y into Z // XorBytes puts X XOR Y into Z. len(Z) and len(X) must be >= len(Y).
func XorBytes(Z, X, Y []byte) { func XorBytes(Z, X, Y []byte) {
for i := 0; i < len(X); i++ { for i := 0; i < len(Y); i++ {
Z[i] = X[i] ^ Y[i] Z[i] = X[i] ^ Y[i]
} }
} }

View File

@@ -18,8 +18,9 @@ import (
"crypto/cipher" "crypto/cipher"
"crypto/subtle" "crypto/subtle"
"errors" "errors"
"github.com/ProtonMail/go-crypto/internal/byteutil"
"math/bits" "math/bits"
"github.com/ProtonMail/go-crypto/internal/byteutil"
) )
type ocb struct { type ocb struct {
@@ -108,8 +109,10 @@ func (o *ocb) Seal(dst, nonce, plaintext, adata []byte) []byte {
if len(nonce) > o.nonceSize { if len(nonce) > o.nonceSize {
panic("crypto/ocb: Incorrect nonce length given to OCB") panic("crypto/ocb: Incorrect nonce length given to OCB")
} }
ret, out := byteutil.SliceForAppend(dst, len(plaintext)+o.tagSize) sep := len(plaintext)
o.crypt(enc, out, nonce, adata, plaintext) ret, out := byteutil.SliceForAppend(dst, sep+o.tagSize)
tag := o.crypt(enc, out[:sep], nonce, adata, plaintext)
copy(out[sep:], tag)
return ret return ret
} }
@@ -121,12 +124,10 @@ func (o *ocb) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) {
return nil, ocbError("Ciphertext shorter than tag length") return nil, ocbError("Ciphertext shorter than tag length")
} }
sep := len(ciphertext) - o.tagSize sep := len(ciphertext) - o.tagSize
ret, out := byteutil.SliceForAppend(dst, len(ciphertext)) ret, out := byteutil.SliceForAppend(dst, sep)
ciphertextData := ciphertext[:sep] ciphertextData := ciphertext[:sep]
tag := ciphertext[sep:] tag := o.crypt(dec, out, nonce, adata, ciphertextData)
o.crypt(dec, out, nonce, adata, ciphertextData) if subtle.ConstantTimeCompare(tag, ciphertext[sep:]) == 1 {
if subtle.ConstantTimeCompare(ret[sep:], tag) == 1 {
ret = ret[:sep]
return ret, nil return ret, nil
} }
for i := range out { for i := range out {
@@ -136,7 +137,8 @@ func (o *ocb) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) {
} }
// On instruction enc (resp. dec), crypt is the encrypt (resp. decrypt) // On instruction enc (resp. dec), crypt is the encrypt (resp. decrypt)
// function. It returns the resulting plain/ciphertext with the tag appended. // function. It writes the resulting plain/ciphertext into Y and returns
// the tag.
func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte { func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
// //
// Consider X as a sequence of 128-bit blocks // Consider X as a sequence of 128-bit blocks
@@ -153,7 +155,7 @@ func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
truncatedNonce := make([]byte, len(nonce)) truncatedNonce := make([]byte, len(nonce))
copy(truncatedNonce, nonce) copy(truncatedNonce, nonce)
truncatedNonce[len(truncatedNonce)-1] &= 192 truncatedNonce[len(truncatedNonce)-1] &= 192
Ktop := make([]byte, blockSize) var Ktop []byte
if bytes.Equal(truncatedNonce, o.reusableKtop.noncePrefix) { if bytes.Equal(truncatedNonce, o.reusableKtop.noncePrefix) {
Ktop = o.reusableKtop.Ktop Ktop = o.reusableKtop.Ktop
} else { } else {
@@ -193,13 +195,14 @@ func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
byteutil.XorBytesMut(offset, o.mask.L[bits.TrailingZeros(uint(i+1))]) byteutil.XorBytesMut(offset, o.mask.L[bits.TrailingZeros(uint(i+1))])
blockX := X[i*blockSize : (i+1)*blockSize] blockX := X[i*blockSize : (i+1)*blockSize]
blockY := Y[i*blockSize : (i+1)*blockSize] blockY := Y[i*blockSize : (i+1)*blockSize]
byteutil.XorBytes(blockY, blockX, offset)
switch instruction { switch instruction {
case enc: case enc:
byteutil.XorBytesMut(checksum, blockX)
byteutil.XorBytes(blockY, blockX, offset)
o.block.Encrypt(blockY, blockY) o.block.Encrypt(blockY, blockY)
byteutil.XorBytesMut(blockY, offset) byteutil.XorBytesMut(blockY, offset)
byteutil.XorBytesMut(checksum, blockX)
case dec: case dec:
byteutil.XorBytes(blockY, blockX, offset)
o.block.Decrypt(blockY, blockY) o.block.Decrypt(blockY, blockY)
byteutil.XorBytesMut(blockY, offset) byteutil.XorBytesMut(blockY, offset)
byteutil.XorBytesMut(checksum, blockY) byteutil.XorBytesMut(checksum, blockY)
@@ -215,31 +218,24 @@ func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
o.block.Encrypt(pad, offset) o.block.Encrypt(pad, offset)
chunkX := X[blockSize*m:] chunkX := X[blockSize*m:]
chunkY := Y[blockSize*m : len(X)] chunkY := Y[blockSize*m : len(X)]
byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
// P_* || bit(1) || zeroes(127) - len(P_*)
switch instruction { switch instruction {
case enc: case enc:
paddedY := append(chunkX, byte(128)) byteutil.XorBytesMut(checksum, chunkX)
paddedY = append(paddedY, make([]byte, blockSize-len(chunkX)-1)...) checksum[len(chunkX)] ^= 128
byteutil.XorBytesMut(checksum, paddedY) byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
// P_* || bit(1) || zeroes(127) - len(P_*)
case dec: case dec:
paddedX := append(chunkY, byte(128)) byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
paddedX = append(paddedX, make([]byte, blockSize-len(chunkY)-1)...) // P_* || bit(1) || zeroes(127) - len(P_*)
byteutil.XorBytesMut(checksum, paddedX) byteutil.XorBytesMut(checksum, chunkY)
checksum[len(chunkY)] ^= 128
} }
byteutil.XorBytes(tag, checksum, offset)
byteutil.XorBytesMut(tag, o.mask.lDol)
o.block.Encrypt(tag, tag)
byteutil.XorBytesMut(tag, o.hash(adata))
copy(Y[blockSize*m+len(chunkY):], tag[:o.tagSize])
} else {
byteutil.XorBytes(tag, checksum, offset)
byteutil.XorBytesMut(tag, o.mask.lDol)
o.block.Encrypt(tag, tag)
byteutil.XorBytesMut(tag, o.hash(adata))
copy(Y[blockSize*m:], tag[:o.tagSize])
} }
return Y byteutil.XorBytes(tag, checksum, offset)
byteutil.XorBytesMut(tag, o.mask.lDol)
o.block.Encrypt(tag, tag)
byteutil.XorBytesMut(tag, o.hash(adata))
return tag[:o.tagSize]
} }
// This hash function is used to compute the tag. Per design, on empty input it // This hash function is used to compute the tag. Per design, on empty input it

View File

@@ -23,7 +23,7 @@ import (
// Headers // Headers
// //
// base64-encoded Bytes // base64-encoded Bytes
// '=' base64 encoded checksum // '=' base64 encoded checksum (optional) not checked anymore
// -----END Type----- // -----END Type-----
// //
// where Headers is a possibly empty sequence of Key: Value lines. // where Headers is a possibly empty sequence of Key: Value lines.
@@ -40,36 +40,15 @@ type Block struct {
var ArmorCorrupt error = errors.StructuralError("armor invalid") var ArmorCorrupt error = errors.StructuralError("armor invalid")
const crc24Init = 0xb704ce
const crc24Poly = 0x1864cfb
const crc24Mask = 0xffffff
// crc24 calculates the OpenPGP checksum as specified in RFC 4880, section 6.1
func crc24(crc uint32, d []byte) uint32 {
for _, b := range d {
crc ^= uint32(b) << 16
for i := 0; i < 8; i++ {
crc <<= 1
if crc&0x1000000 != 0 {
crc ^= crc24Poly
}
}
}
return crc
}
var armorStart = []byte("-----BEGIN ") var armorStart = []byte("-----BEGIN ")
var armorEnd = []byte("-----END ") var armorEnd = []byte("-----END ")
var armorEndOfLine = []byte("-----") var armorEndOfLine = []byte("-----")
// lineReader wraps a line based reader. It watches for the end of an armor // lineReader wraps a line based reader. It watches for the end of an armor block
// block and records the expected CRC value.
type lineReader struct { type lineReader struct {
in *bufio.Reader in *bufio.Reader
buf []byte buf []byte
eof bool eof bool
crc uint32
crcSet bool
} }
func (l *lineReader) Read(p []byte) (n int, err error) { func (l *lineReader) Read(p []byte) (n int, err error) {
@@ -98,26 +77,9 @@ func (l *lineReader) Read(p []byte) (n int, err error) {
if len(line) == 5 && line[0] == '=' { if len(line) == 5 && line[0] == '=' {
// This is the checksum line // This is the checksum line
var expectedBytes [3]byte // Don't check the checksum
var m int
m, err = base64.StdEncoding.Decode(expectedBytes[0:], line[1:])
if m != 3 || err != nil {
return
}
l.crc = uint32(expectedBytes[0])<<16 |
uint32(expectedBytes[1])<<8 |
uint32(expectedBytes[2])
line, _, err = l.in.ReadLine()
if err != nil && err != io.EOF {
return
}
if !bytes.HasPrefix(line, armorEnd) {
return 0, ArmorCorrupt
}
l.eof = true l.eof = true
l.crcSet = true
return 0, io.EOF return 0, io.EOF
} }
@@ -138,23 +100,14 @@ func (l *lineReader) Read(p []byte) (n int, err error) {
return return
} }
// openpgpReader passes Read calls to the underlying base64 decoder, but keeps // openpgpReader passes Read calls to the underlying base64 decoder.
// a running CRC of the resulting data and checks the CRC against the value
// found by the lineReader at EOF.
type openpgpReader struct { type openpgpReader struct {
lReader *lineReader lReader *lineReader
b64Reader io.Reader b64Reader io.Reader
currentCRC uint32
} }
func (r *openpgpReader) Read(p []byte) (n int, err error) { func (r *openpgpReader) Read(p []byte) (n int, err error) {
n, err = r.b64Reader.Read(p) n, err = r.b64Reader.Read(p)
r.currentCRC = crc24(r.currentCRC, p[:n])
if err == io.EOF && r.lReader.crcSet && r.lReader.crc != uint32(r.currentCRC&crc24Mask) {
return 0, ArmorCorrupt
}
return return
} }
@@ -222,7 +175,6 @@ TryNextBlock:
} }
p.lReader.in = r p.lReader.in = r
p.oReader.currentCRC = crc24Init
p.oReader.lReader = &p.lReader p.oReader.lReader = &p.lReader
p.oReader.b64Reader = base64.NewDecoder(base64.StdEncoding, &p.lReader) p.oReader.b64Reader = base64.NewDecoder(base64.StdEncoding, &p.lReader)
p.Body = &p.oReader p.Body = &p.oReader

View File

@@ -7,6 +7,7 @@ package armor
import ( import (
"encoding/base64" "encoding/base64"
"io" "io"
"sort"
) )
var armorHeaderSep = []byte(": ") var armorHeaderSep = []byte(": ")
@@ -14,6 +15,23 @@ var blockEnd = []byte("\n=")
var newline = []byte("\n") var newline = []byte("\n")
var armorEndOfLineOut = []byte("-----\n") var armorEndOfLineOut = []byte("-----\n")
const crc24Init = 0xb704ce
const crc24Poly = 0x1864cfb
// crc24 calculates the OpenPGP checksum as specified in RFC 4880, section 6.1
func crc24(crc uint32, d []byte) uint32 {
for _, b := range d {
crc ^= uint32(b) << 16
for i := 0; i < 8; i++ {
crc <<= 1
if crc&0x1000000 != 0 {
crc ^= crc24Poly
}
}
}
return crc
}
// writeSlices writes its arguments to the given Writer. // writeSlices writes its arguments to the given Writer.
func writeSlices(out io.Writer, slices ...[]byte) (err error) { func writeSlices(out io.Writer, slices ...[]byte) (err error) {
for _, s := range slices { for _, s := range slices {
@@ -99,15 +117,18 @@ func (l *lineBreaker) Close() (err error) {
// //
// encoding -> base64 encoder -> lineBreaker -> out // encoding -> base64 encoder -> lineBreaker -> out
type encoding struct { type encoding struct {
out io.Writer out io.Writer
breaker *lineBreaker breaker *lineBreaker
b64 io.WriteCloser b64 io.WriteCloser
crc uint32 crc uint32
blockType []byte crcEnabled bool
blockType []byte
} }
func (e *encoding) Write(data []byte) (n int, err error) { func (e *encoding) Write(data []byte) (n int, err error) {
e.crc = crc24(e.crc, data) if e.crcEnabled {
e.crc = crc24(e.crc, data)
}
return e.b64.Write(data) return e.b64.Write(data)
} }
@@ -118,28 +139,36 @@ func (e *encoding) Close() (err error) {
} }
e.breaker.Close() e.breaker.Close()
var checksumBytes [3]byte if e.crcEnabled {
checksumBytes[0] = byte(e.crc >> 16) var checksumBytes [3]byte
checksumBytes[1] = byte(e.crc >> 8) checksumBytes[0] = byte(e.crc >> 16)
checksumBytes[2] = byte(e.crc) checksumBytes[1] = byte(e.crc >> 8)
checksumBytes[2] = byte(e.crc)
var b64ChecksumBytes [4]byte var b64ChecksumBytes [4]byte
base64.StdEncoding.Encode(b64ChecksumBytes[:], checksumBytes[:]) base64.StdEncoding.Encode(b64ChecksumBytes[:], checksumBytes[:])
return writeSlices(e.out, blockEnd, b64ChecksumBytes[:], newline, armorEnd, e.blockType, armorEndOfLine) return writeSlices(e.out, blockEnd, b64ChecksumBytes[:], newline, armorEnd, e.blockType, armorEndOfLine)
}
return writeSlices(e.out, newline, armorEnd, e.blockType, armorEndOfLine)
} }
// Encode returns a WriteCloser which will encode the data written to it in func encode(out io.Writer, blockType string, headers map[string]string, checksum bool) (w io.WriteCloser, err error) {
// OpenPGP armor.
func Encode(out io.Writer, blockType string, headers map[string]string) (w io.WriteCloser, err error) {
bType := []byte(blockType) bType := []byte(blockType)
err = writeSlices(out, armorStart, bType, armorEndOfLineOut) err = writeSlices(out, armorStart, bType, armorEndOfLineOut)
if err != nil { if err != nil {
return return
} }
for k, v := range headers { keys := make([]string, len(headers))
err = writeSlices(out, []byte(k), armorHeaderSep, []byte(v), newline) i := 0
for k := range headers {
keys[i] = k
i++
}
sort.Strings(keys)
for _, k := range keys {
err = writeSlices(out, []byte(k), armorHeaderSep, []byte(headers[k]), newline)
if err != nil { if err != nil {
return return
} }
@@ -151,11 +180,27 @@ func Encode(out io.Writer, blockType string, headers map[string]string) (w io.Wr
} }
e := &encoding{ e := &encoding{
out: out, out: out,
breaker: newLineBreaker(out, 64), breaker: newLineBreaker(out, 64),
crc: crc24Init, blockType: bType,
blockType: bType, crc: crc24Init,
crcEnabled: checksum,
} }
e.b64 = base64.NewEncoder(base64.StdEncoding, e.breaker) e.b64 = base64.NewEncoder(base64.StdEncoding, e.breaker)
return e, nil return e, nil
} }
// Encode returns a WriteCloser which will encode the data written to it in
// OpenPGP armor.
func Encode(out io.Writer, blockType string, headers map[string]string) (w io.WriteCloser, err error) {
return encode(out, blockType, headers, true)
}
// EncodeWithChecksumOption returns a WriteCloser which will encode the data written to it in
// OpenPGP armor and provides the option to include a checksum.
// When forming ASCII Armor, the CRC24 footer SHOULD NOT be generated,
// unless interoperability with implementations that require the CRC24 footer
// to be present is a concern.
func EncodeWithChecksumOption(out io.Writer, blockType string, headers map[string]string, doChecksum bool) (w io.WriteCloser, err error) {
return encode(out, blockType, headers, doChecksum)
}

View File

@@ -30,8 +30,12 @@ func writeCanonical(cw io.Writer, buf []byte, s *int) (int, error) {
if c == '\r' { if c == '\r' {
*s = 1 *s = 1
} else if c == '\n' { } else if c == '\n' {
cw.Write(buf[start:i]) if _, err := cw.Write(buf[start:i]); err != nil {
cw.Write(newline) return 0, err
}
if _, err := cw.Write(newline); err != nil {
return 0, err
}
start = i + 1 start = i + 1
} }
case 1: case 1:
@@ -39,7 +43,9 @@ func writeCanonical(cw io.Writer, buf []byte, s *int) (int, error) {
} }
} }
cw.Write(buf[start:]) if _, err := cw.Write(buf[start:]); err != nil {
return 0, err
}
return len(buf), nil return len(buf), nil
} }

View File

@@ -163,13 +163,9 @@ func buildKey(pub *PublicKey, zb []byte, curveOID, fingerprint []byte, stripLead
if _, err := param.Write([]byte("Anonymous Sender ")); err != nil { if _, err := param.Write([]byte("Anonymous Sender ")); err != nil {
return nil, err return nil, err
} }
// For v5 keys, the 20 leftmost octets of the fingerprint are used. if _, err := param.Write(fingerprint[:]); err != nil {
if _, err := param.Write(fingerprint[:20]); err != nil {
return nil, err return nil, err
} }
if param.Len()-len(curveOID) != 45 {
return nil, errors.New("ecdh: malformed KDF Param")
}
// MB = Hash ( 00 || 00 || 00 || 01 || ZB || Param ); // MB = Hash ( 00 || 00 || 00 || 01 || ZB || Param );
h := pub.KDF.Hash.New() h := pub.KDF.Hash.New()

View File

@@ -0,0 +1,115 @@
// Package ed25519 implements the ed25519 signature algorithm for OpenPGP
// as defined in the Open PGP crypto refresh.
package ed25519
import (
"crypto/subtle"
"io"
"github.com/ProtonMail/go-crypto/openpgp/errors"
ed25519lib "github.com/cloudflare/circl/sign/ed25519"
)
const (
// PublicKeySize is the size, in bytes, of public keys in this package.
PublicKeySize = ed25519lib.PublicKeySize
// SeedSize is the size, in bytes, of private key seeds.
// The private key representation used by RFC 8032.
SeedSize = ed25519lib.SeedSize
// SignatureSize is the size, in bytes, of signatures generated and verified by this package.
SignatureSize = ed25519lib.SignatureSize
)
type PublicKey struct {
// Point represents the elliptic curve point of the public key.
Point []byte
}
type PrivateKey struct {
PublicKey
// Key the private key representation by RFC 8032,
// encoded as seed | pub key point.
Key []byte
}
// NewPublicKey creates a new empty ed25519 public key.
func NewPublicKey() *PublicKey {
return &PublicKey{}
}
// NewPrivateKey creates a new empty private key referencing the public key.
func NewPrivateKey(key PublicKey) *PrivateKey {
return &PrivateKey{
PublicKey: key,
}
}
// Seed returns the ed25519 private key secret seed.
// The private key representation by RFC 8032.
func (pk *PrivateKey) Seed() []byte {
return pk.Key[:SeedSize]
}
// MarshalByteSecret returns the underlying 32 byte seed of the private key.
func (pk *PrivateKey) MarshalByteSecret() []byte {
return pk.Seed()
}
// UnmarshalByteSecret computes the private key from the secret seed
// and stores it in the private key object.
func (sk *PrivateKey) UnmarshalByteSecret(seed []byte) error {
sk.Key = ed25519lib.NewKeyFromSeed(seed)
return nil
}
// GenerateKey generates a fresh private key with the provided randomness source.
func GenerateKey(rand io.Reader) (*PrivateKey, error) {
publicKey, privateKey, err := ed25519lib.GenerateKey(rand)
if err != nil {
return nil, err
}
privateKeyOut := new(PrivateKey)
privateKeyOut.PublicKey.Point = publicKey[:]
privateKeyOut.Key = privateKey[:]
return privateKeyOut, nil
}
// Sign signs a message with the ed25519 algorithm.
// priv MUST be a valid key! Check this with Validate() before use.
func Sign(priv *PrivateKey, message []byte) ([]byte, error) {
return ed25519lib.Sign(priv.Key, message), nil
}
// Verify verifies an ed25519 signature.
func Verify(pub *PublicKey, message []byte, signature []byte) bool {
return ed25519lib.Verify(pub.Point, message, signature)
}
// Validate checks if the ed25519 private key is valid.
func Validate(priv *PrivateKey) error {
expectedPrivateKey := ed25519lib.NewKeyFromSeed(priv.Seed())
if subtle.ConstantTimeCompare(priv.Key, expectedPrivateKey) == 0 {
return errors.KeyInvalidError("ed25519: invalid ed25519 secret")
}
if subtle.ConstantTimeCompare(priv.PublicKey.Point, expectedPrivateKey[SeedSize:]) == 0 {
return errors.KeyInvalidError("ed25519: invalid ed25519 public key")
}
return nil
}
// ENCODING/DECODING signature:
// WriteSignature encodes and writes an ed25519 signature to writer.
func WriteSignature(writer io.Writer, signature []byte) error {
_, err := writer.Write(signature)
return err
}
// ReadSignature decodes an ed25519 signature from a reader.
func ReadSignature(reader io.Reader) ([]byte, error) {
signature := make([]byte, SignatureSize)
if _, err := io.ReadFull(reader, signature); err != nil {
return nil, err
}
return signature, nil
}

View File

@@ -0,0 +1,119 @@
// Package ed448 implements the ed448 signature algorithm for OpenPGP
// as defined in the Open PGP crypto refresh.
package ed448
import (
"crypto/subtle"
"io"
"github.com/ProtonMail/go-crypto/openpgp/errors"
ed448lib "github.com/cloudflare/circl/sign/ed448"
)
const (
// PublicKeySize is the size, in bytes, of public keys in this package.
PublicKeySize = ed448lib.PublicKeySize
// SeedSize is the size, in bytes, of private key seeds.
// The private key representation used by RFC 8032.
SeedSize = ed448lib.SeedSize
// SignatureSize is the size, in bytes, of signatures generated and verified by this package.
SignatureSize = ed448lib.SignatureSize
)
type PublicKey struct {
// Point represents the elliptic curve point of the public key.
Point []byte
}
type PrivateKey struct {
PublicKey
// Key the private key representation by RFC 8032,
// encoded as seed | public key point.
Key []byte
}
// NewPublicKey creates a new empty ed448 public key.
func NewPublicKey() *PublicKey {
return &PublicKey{}
}
// NewPrivateKey creates a new empty private key referencing the public key.
func NewPrivateKey(key PublicKey) *PrivateKey {
return &PrivateKey{
PublicKey: key,
}
}
// Seed returns the ed448 private key secret seed.
// The private key representation by RFC 8032.
func (pk *PrivateKey) Seed() []byte {
return pk.Key[:SeedSize]
}
// MarshalByteSecret returns the underlying seed of the private key.
func (pk *PrivateKey) MarshalByteSecret() []byte {
return pk.Seed()
}
// UnmarshalByteSecret computes the private key from the secret seed
// and stores it in the private key object.
func (sk *PrivateKey) UnmarshalByteSecret(seed []byte) error {
sk.Key = ed448lib.NewKeyFromSeed(seed)
return nil
}
// GenerateKey generates a fresh private key with the provided randomness source.
func GenerateKey(rand io.Reader) (*PrivateKey, error) {
publicKey, privateKey, err := ed448lib.GenerateKey(rand)
if err != nil {
return nil, err
}
privateKeyOut := new(PrivateKey)
privateKeyOut.PublicKey.Point = publicKey[:]
privateKeyOut.Key = privateKey[:]
return privateKeyOut, nil
}
// Sign signs a message with the ed448 algorithm.
// priv MUST be a valid key! Check this with Validate() before use.
func Sign(priv *PrivateKey, message []byte) ([]byte, error) {
// Ed448 is used with the empty string as a context string.
// See https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh-08#section-13.7
return ed448lib.Sign(priv.Key, message, ""), nil
}
// Verify verifies a ed448 signature
func Verify(pub *PublicKey, message []byte, signature []byte) bool {
// Ed448 is used with the empty string as a context string.
// See https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh-08#section-13.7
return ed448lib.Verify(pub.Point, message, signature, "")
}
// Validate checks if the ed448 private key is valid
func Validate(priv *PrivateKey) error {
expectedPrivateKey := ed448lib.NewKeyFromSeed(priv.Seed())
if subtle.ConstantTimeCompare(priv.Key, expectedPrivateKey) == 0 {
return errors.KeyInvalidError("ed448: invalid ed448 secret")
}
if subtle.ConstantTimeCompare(priv.PublicKey.Point, expectedPrivateKey[SeedSize:]) == 0 {
return errors.KeyInvalidError("ed448: invalid ed448 public key")
}
return nil
}
// ENCODING/DECODING signature:
// WriteSignature encodes and writes an ed448 signature to writer.
func WriteSignature(writer io.Writer, signature []byte) error {
_, err := writer.Write(signature)
return err
}
// ReadSignature decodes an ed448 signature from a reader.
func ReadSignature(reader io.Reader) ([]byte, error) {
signature := make([]byte, SignatureSize)
if _, err := io.ReadFull(reader, signature); err != nil {
return nil, err
}
return signature, nil
}

View File

@@ -6,9 +6,22 @@
package errors // import "github.com/ProtonMail/go-crypto/openpgp/errors" package errors // import "github.com/ProtonMail/go-crypto/openpgp/errors"
import ( import (
"fmt"
"strconv" "strconv"
) )
var (
// ErrDecryptSessionKeyParsing is a generic error message for parsing errors in decrypted data
// to reduce the risk of oracle attacks.
ErrDecryptSessionKeyParsing = DecryptWithSessionKeyError("parsing error")
// ErrAEADTagVerification is returned if one of the tag verifications in SEIPDv2 fails
ErrAEADTagVerification error = DecryptWithSessionKeyError("AEAD tag verification failed")
// ErrMDCHashMismatch
ErrMDCHashMismatch error = SignatureError("MDC hash mismatch")
// ErrMDCMissing
ErrMDCMissing error = SignatureError("MDC packet not found")
)
// A StructuralError is returned when OpenPGP data is found to be syntactically // A StructuralError is returned when OpenPGP data is found to be syntactically
// invalid. // invalid.
type StructuralError string type StructuralError string
@@ -17,6 +30,34 @@ func (s StructuralError) Error() string {
return "openpgp: invalid data: " + string(s) return "openpgp: invalid data: " + string(s)
} }
// A DecryptWithSessionKeyError is returned when a failure occurs when reading from symmetrically decrypted data or
// an authentication tag verification fails.
// Such an error indicates that the supplied session key is likely wrong or the data got corrupted.
type DecryptWithSessionKeyError string
func (s DecryptWithSessionKeyError) Error() string {
return "openpgp: decryption with session key failed: " + string(s)
}
// HandleSensitiveParsingError handles parsing errors when reading data from potentially decrypted data.
// The function makes parsing errors generic to reduce the risk of oracle attacks in SEIPDv1.
func HandleSensitiveParsingError(err error, decrypted bool) error {
if !decrypted {
// Data was not encrypted so we return the inner error.
return err
}
// The data is read from a stream that decrypts using a session key;
// therefore, we need to handle parsing errors appropriately.
// This is essential to mitigate the risk of oracle attacks.
if decError, ok := err.(*DecryptWithSessionKeyError); ok {
return decError
}
if decError, ok := err.(DecryptWithSessionKeyError); ok {
return decError
}
return ErrDecryptSessionKeyParsing
}
// UnsupportedError indicates that, although the OpenPGP data is valid, it // UnsupportedError indicates that, although the OpenPGP data is valid, it
// makes use of currently unimplemented features. // makes use of currently unimplemented features.
type UnsupportedError string type UnsupportedError string
@@ -41,9 +82,6 @@ func (b SignatureError) Error() string {
return "openpgp: invalid signature: " + string(b) return "openpgp: invalid signature: " + string(b)
} }
var ErrMDCHashMismatch error = SignatureError("MDC hash mismatch")
var ErrMDCMissing error = SignatureError("MDC packet not found")
type signatureExpiredError int type signatureExpiredError int
func (se signatureExpiredError) Error() string { func (se signatureExpiredError) Error() string {
@@ -58,6 +96,14 @@ func (ke keyExpiredError) Error() string {
return "openpgp: key expired" return "openpgp: key expired"
} }
var ErrSignatureOlderThanKey error = signatureOlderThanKeyError(0)
type signatureOlderThanKeyError int
func (ske signatureOlderThanKeyError) Error() string {
return "openpgp: signature is older than the key"
}
var ErrKeyExpired error = keyExpiredError(0) var ErrKeyExpired error = keyExpiredError(0)
type keyIncorrectError int type keyIncorrectError int
@@ -92,12 +138,24 @@ func (keyRevokedError) Error() string {
var ErrKeyRevoked error = keyRevokedError(0) var ErrKeyRevoked error = keyRevokedError(0)
type WeakAlgorithmError string
func (e WeakAlgorithmError) Error() string {
return "openpgp: weak algorithms are rejected: " + string(e)
}
type UnknownPacketTypeError uint8 type UnknownPacketTypeError uint8
func (upte UnknownPacketTypeError) Error() string { func (upte UnknownPacketTypeError) Error() string {
return "openpgp: unknown packet type: " + strconv.Itoa(int(upte)) return "openpgp: unknown packet type: " + strconv.Itoa(int(upte))
} }
type CriticalUnknownPacketTypeError uint8
func (upte CriticalUnknownPacketTypeError) Error() string {
return "openpgp: unknown critical packet type: " + strconv.Itoa(int(upte))
}
// AEADError indicates that there is a problem when initializing or using a // AEADError indicates that there is a problem when initializing or using a
// AEAD instance, configuration struct, nonces or index values. // AEAD instance, configuration struct, nonces or index values.
type AEADError string type AEADError string
@@ -114,3 +172,39 @@ type ErrDummyPrivateKey string
func (dke ErrDummyPrivateKey) Error() string { func (dke ErrDummyPrivateKey) Error() string {
return "openpgp: s2k GNU dummy key: " + string(dke) return "openpgp: s2k GNU dummy key: " + string(dke)
} }
// ErrMalformedMessage results when the packet sequence is incorrect
type ErrMalformedMessage string
func (dke ErrMalformedMessage) Error() string {
return "openpgp: malformed message " + string(dke)
}
type messageTooLargeError int
func (e messageTooLargeError) Error() string {
return "openpgp: decompressed message size exceeds provided limit"
}
// ErrMessageTooLarge is returned if the read data from
// a compressed packet exceeds the provided limit.
var ErrMessageTooLarge error = messageTooLargeError(0)
// ErrEncryptionKeySelection is returned if encryption key selection fails (v2 API).
type ErrEncryptionKeySelection struct {
PrimaryKeyId string
PrimaryKeyErr error
EncSelectionKeyId *string
EncSelectionErr error
}
func (eks ErrEncryptionKeySelection) Error() string {
prefix := fmt.Sprintf("openpgp: key selection for primary key %s:", eks.PrimaryKeyId)
if eks.PrimaryKeyErr != nil {
return fmt.Sprintf("%s invalid primary key: %s", prefix, eks.PrimaryKeyErr)
}
if eks.EncSelectionKeyId != nil {
return fmt.Sprintf("%s invalid encryption key %s: %s", prefix, *eks.EncSelectionKeyId, eks.EncSelectionErr)
}
return fmt.Sprintf("%s no encryption key: %s", prefix, eks.EncSelectionErr)
}

View File

@@ -51,24 +51,14 @@ func (sk CipherFunction) Id() uint8 {
return uint8(sk) return uint8(sk)
} }
var keySizeByID = map[uint8]int{
TripleDES.Id(): 24,
CAST5.Id(): cast5.KeySize,
AES128.Id(): 16,
AES192.Id(): 24,
AES256.Id(): 32,
}
// KeySize returns the key size, in bytes, of cipher. // KeySize returns the key size, in bytes, of cipher.
func (cipher CipherFunction) KeySize() int { func (cipher CipherFunction) KeySize() int {
switch cipher { switch cipher {
case TripleDES:
return 24
case CAST5: case CAST5:
return cast5.KeySize return cast5.KeySize
case AES128: case AES128:
return 16 return 16
case AES192: case AES192, TripleDES:
return 24 return 24
case AES256: case AES256:
return 32 return 32

View File

@@ -4,11 +4,14 @@ package ecc
import ( import (
"bytes" "bytes"
"crypto/elliptic" "crypto/elliptic"
"github.com/ProtonMail/go-crypto/bitcurves" "github.com/ProtonMail/go-crypto/bitcurves"
"github.com/ProtonMail/go-crypto/brainpool" "github.com/ProtonMail/go-crypto/brainpool"
"github.com/ProtonMail/go-crypto/openpgp/internal/encoding" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding"
) )
const Curve25519GenName = "Curve25519"
type CurveInfo struct { type CurveInfo struct {
GenName string GenName string
Oid *encoding.OID Oid *encoding.OID
@@ -42,19 +45,19 @@ var Curves = []CurveInfo{
}, },
{ {
// Curve25519 // Curve25519
GenName: "Curve25519", GenName: Curve25519GenName,
Oid: encoding.NewOID([]byte{0x2B, 0x06, 0x01, 0x04, 0x01, 0x97, 0x55, 0x01, 0x05, 0x01}), Oid: encoding.NewOID([]byte{0x2B, 0x06, 0x01, 0x04, 0x01, 0x97, 0x55, 0x01, 0x05, 0x01}),
Curve: NewCurve25519(), Curve: NewCurve25519(),
}, },
{ {
// X448 // x448
GenName: "Curve448", GenName: "Curve448",
Oid: encoding.NewOID([]byte{0x2B, 0x65, 0x6F}), Oid: encoding.NewOID([]byte{0x2B, 0x65, 0x6F}),
Curve: NewX448(), Curve: NewX448(),
}, },
{ {
// Ed25519 // Ed25519
GenName: "Curve25519", GenName: Curve25519GenName,
Oid: encoding.NewOID([]byte{0x2B, 0x06, 0x01, 0x04, 0x01, 0xDA, 0x47, 0x0F, 0x01}), Oid: encoding.NewOID([]byte{0x2B, 0x06, 0x01, 0x04, 0x01, 0xDA, 0x47, 0x0F, 0x01}),
Curve: NewEd25519(), Curve: NewEd25519(),
}, },

View File

@@ -2,6 +2,7 @@
package ecc package ecc
import ( import (
"bytes"
"crypto/subtle" "crypto/subtle"
"io" "io"
@@ -90,7 +91,14 @@ func (c *ed25519) GenerateEdDSA(rand io.Reader) (pub, priv []byte, err error) {
} }
func getEd25519Sk(publicKey, privateKey []byte) ed25519lib.PrivateKey { func getEd25519Sk(publicKey, privateKey []byte) ed25519lib.PrivateKey {
return append(privateKey, publicKey...) privateKeyCap, privateKeyLen, publicKeyLen := cap(privateKey), len(privateKey), len(publicKey)
if privateKeyCap >= privateKeyLen+publicKeyLen &&
bytes.Equal(privateKey[privateKeyLen:privateKeyLen+publicKeyLen], publicKey) {
return privateKey[:privateKeyLen+publicKeyLen]
}
return append(privateKey[:privateKeyLen:privateKeyLen], publicKey...)
} }
func (c *ed25519) Sign(publicKey, privateKey, message []byte) (sig []byte, err error) { func (c *ed25519) Sign(publicKey, privateKey, message []byte) (sig []byte, err error) {

View File

@@ -2,6 +2,7 @@
package ecc package ecc
import ( import (
"bytes"
"crypto/subtle" "crypto/subtle"
"io" "io"
@@ -84,7 +85,14 @@ func (c *ed448) GenerateEdDSA(rand io.Reader) (pub, priv []byte, err error) {
} }
func getEd448Sk(publicKey, privateKey []byte) ed448lib.PrivateKey { func getEd448Sk(publicKey, privateKey []byte) ed448lib.PrivateKey {
return append(privateKey, publicKey...) privateKeyCap, privateKeyLen, publicKeyLen := cap(privateKey), len(privateKey), len(publicKey)
if privateKeyCap >= privateKeyLen+publicKeyLen &&
bytes.Equal(privateKey[privateKeyLen:privateKeyLen+publicKeyLen], publicKey) {
return privateKey[:privateKeyLen+publicKeyLen]
}
return append(privateKey[:privateKeyLen:privateKeyLen], publicKey...)
} }
func (c *ed448) Sign(publicKey, privateKey, message []byte) (sig []byte, err error) { func (c *ed448) Sign(publicKey, privateKey, message []byte) (sig []byte, err error) {

View File

@@ -15,11 +15,15 @@ import (
"github.com/ProtonMail/go-crypto/openpgp/ecdh" "github.com/ProtonMail/go-crypto/openpgp/ecdh"
"github.com/ProtonMail/go-crypto/openpgp/ecdsa" "github.com/ProtonMail/go-crypto/openpgp/ecdsa"
"github.com/ProtonMail/go-crypto/openpgp/ed25519"
"github.com/ProtonMail/go-crypto/openpgp/ed448"
"github.com/ProtonMail/go-crypto/openpgp/eddsa" "github.com/ProtonMail/go-crypto/openpgp/eddsa"
"github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/errors"
"github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm"
"github.com/ProtonMail/go-crypto/openpgp/internal/ecc" "github.com/ProtonMail/go-crypto/openpgp/internal/ecc"
"github.com/ProtonMail/go-crypto/openpgp/packet" "github.com/ProtonMail/go-crypto/openpgp/packet"
"github.com/ProtonMail/go-crypto/openpgp/x25519"
"github.com/ProtonMail/go-crypto/openpgp/x448"
) )
// NewEntity returns an Entity that contains a fresh RSA/RSA keypair with a // NewEntity returns an Entity that contains a fresh RSA/RSA keypair with a
@@ -36,8 +40,10 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err
return nil, err return nil, err
} }
primary := packet.NewSignerPrivateKey(creationTime, primaryPrivRaw) primary := packet.NewSignerPrivateKey(creationTime, primaryPrivRaw)
if config != nil && config.V5Keys { if config.V6() {
primary.UpgradeToV5() if err := primary.UpgradeToV6(); err != nil {
return nil, err
}
} }
e := &Entity{ e := &Entity{
@@ -45,9 +51,25 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err
PrivateKey: primary, PrivateKey: primary,
Identities: make(map[string]*Identity), Identities: make(map[string]*Identity),
Subkeys: []Subkey{}, Subkeys: []Subkey{},
Signatures: []*packet.Signature{},
} }
err = e.addUserId(name, comment, email, config, creationTime, keyLifetimeSecs) if config.V6() {
// In v6 keys algorithm preferences should be stored in direct key signatures
selfSignature := createSignaturePacket(&primary.PublicKey, packet.SigTypeDirectSignature, config)
err = writeKeyProperties(selfSignature, creationTime, keyLifetimeSecs, config)
if err != nil {
return nil, err
}
err = selfSignature.SignDirectKeyBinding(&primary.PublicKey, primary, config)
if err != nil {
return nil, err
}
e.Signatures = append(e.Signatures, selfSignature)
e.SelfSignature = selfSignature
}
err = e.addUserId(name, comment, email, config, creationTime, keyLifetimeSecs, !config.V6())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -65,32 +87,19 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err
func (t *Entity) AddUserId(name, comment, email string, config *packet.Config) error { func (t *Entity) AddUserId(name, comment, email string, config *packet.Config) error {
creationTime := config.Now() creationTime := config.Now()
keyLifetimeSecs := config.KeyLifetime() keyLifetimeSecs := config.KeyLifetime()
return t.addUserId(name, comment, email, config, creationTime, keyLifetimeSecs) return t.addUserId(name, comment, email, config, creationTime, keyLifetimeSecs, !config.V6())
} }
func (t *Entity) addUserId(name, comment, email string, config *packet.Config, creationTime time.Time, keyLifetimeSecs uint32) error { func writeKeyProperties(selfSignature *packet.Signature, creationTime time.Time, keyLifetimeSecs uint32, config *packet.Config) error {
uid := packet.NewUserId(name, comment, email) advertiseAead := config.AEAD() != nil
if uid == nil {
return errors.InvalidArgumentError("user id field contained invalid characters")
}
if _, ok := t.Identities[uid.Id]; ok {
return errors.InvalidArgumentError("user id exist")
}
primary := t.PrivateKey
isPrimaryId := len(t.Identities) == 0
selfSignature := createSignaturePacket(&primary.PublicKey, packet.SigTypePositiveCert, config)
selfSignature.CreationTime = creationTime selfSignature.CreationTime = creationTime
selfSignature.KeyLifetimeSecs = &keyLifetimeSecs selfSignature.KeyLifetimeSecs = &keyLifetimeSecs
selfSignature.IsPrimaryId = &isPrimaryId
selfSignature.FlagsValid = true selfSignature.FlagsValid = true
selfSignature.FlagSign = true selfSignature.FlagSign = true
selfSignature.FlagCertify = true selfSignature.FlagCertify = true
selfSignature.SEIPDv1 = true // true by default, see 5.8 vs. 5.14 selfSignature.SEIPDv1 = true // true by default, see 5.8 vs. 5.14
selfSignature.SEIPDv2 = config.AEAD() != nil selfSignature.SEIPDv2 = advertiseAead
// Set the PreferredHash for the SelfSignature from the packet.Config. // Set the PreferredHash for the SelfSignature from the packet.Config.
// If it is not the must-implement algorithm from rfc4880bis, append that. // If it is not the must-implement algorithm from rfc4880bis, append that.
@@ -119,18 +128,44 @@ func (t *Entity) addUserId(name, comment, email string, config *packet.Config, c
selfSignature.PreferredCompression = append(selfSignature.PreferredCompression, uint8(config.Compression())) selfSignature.PreferredCompression = append(selfSignature.PreferredCompression, uint8(config.Compression()))
} }
// And for DefaultMode. if advertiseAead {
modes := []uint8{uint8(config.AEAD().Mode())} // Get the preferred AEAD mode from the packet.Config.
if config.AEAD().Mode() != packet.AEADModeOCB { // If it is not the must-implement algorithm from rfc9580, append that.
modes = append(modes, uint8(packet.AEADModeOCB)) modes := []uint8{uint8(config.AEAD().Mode())}
} if config.AEAD().Mode() != packet.AEADModeOCB {
modes = append(modes, uint8(packet.AEADModeOCB))
}
// For preferred (AES256, GCM), we'll generate (AES256, GCM), (AES256, OCB), (AES128, GCM), (AES128, OCB) // For preferred (AES256, GCM), we'll generate (AES256, GCM), (AES256, OCB), (AES128, GCM), (AES128, OCB)
for _, cipher := range selfSignature.PreferredSymmetric { for _, cipher := range selfSignature.PreferredSymmetric {
for _, mode := range modes { for _, mode := range modes {
selfSignature.PreferredCipherSuites = append(selfSignature.PreferredCipherSuites, [2]uint8{cipher, mode}) selfSignature.PreferredCipherSuites = append(selfSignature.PreferredCipherSuites, [2]uint8{cipher, mode})
}
} }
} }
return nil
}
func (t *Entity) addUserId(name, comment, email string, config *packet.Config, creationTime time.Time, keyLifetimeSecs uint32, writeProperties bool) error {
uid := packet.NewUserId(name, comment, email)
if uid == nil {
return errors.InvalidArgumentError("user id field contained invalid characters")
}
if _, ok := t.Identities[uid.Id]; ok {
return errors.InvalidArgumentError("user id exist")
}
primary := t.PrivateKey
isPrimaryId := len(t.Identities) == 0
selfSignature := createSignaturePacket(&primary.PublicKey, packet.SigTypePositiveCert, config)
if writeProperties {
err := writeKeyProperties(selfSignature, creationTime, keyLifetimeSecs, config)
if err != nil {
return err
}
}
selfSignature.IsPrimaryId = &isPrimaryId
// User ID binding signature // User ID binding signature
err := selfSignature.SignUserId(uid.Id, &primary.PublicKey, primary, config) err := selfSignature.SignUserId(uid.Id, &primary.PublicKey, primary, config)
@@ -158,8 +193,10 @@ func (e *Entity) AddSigningSubkey(config *packet.Config) error {
} }
sub := packet.NewSignerPrivateKey(creationTime, subPrivRaw) sub := packet.NewSignerPrivateKey(creationTime, subPrivRaw)
sub.IsSubkey = true sub.IsSubkey = true
if config != nil && config.V5Keys { if config.V6() {
sub.UpgradeToV5() if err := sub.UpgradeToV6(); err != nil {
return err
}
} }
subkey := Subkey{ subkey := Subkey{
@@ -203,8 +240,10 @@ func (e *Entity) addEncryptionSubkey(config *packet.Config, creationTime time.Ti
} }
sub := packet.NewDecrypterPrivateKey(creationTime, subPrivRaw) sub := packet.NewDecrypterPrivateKey(creationTime, subPrivRaw)
sub.IsSubkey = true sub.IsSubkey = true
if config != nil && config.V5Keys { if config.V6() {
sub.UpgradeToV5() if err := sub.UpgradeToV6(); err != nil {
return err
}
} }
subkey := Subkey{ subkey := Subkey{
@@ -242,6 +281,11 @@ func newSigner(config *packet.Config) (signer interface{}, err error) {
} }
return rsa.GenerateKey(config.Random(), bits) return rsa.GenerateKey(config.Random(), bits)
case packet.PubKeyAlgoEdDSA: case packet.PubKeyAlgoEdDSA:
if config.V6() {
// Implementations MUST NOT accept or generate v6 key material
// using the deprecated OIDs.
return nil, errors.InvalidArgumentError("EdDSALegacy cannot be used for v6 keys")
}
curve := ecc.FindEdDSAByGenName(string(config.CurveName())) curve := ecc.FindEdDSAByGenName(string(config.CurveName()))
if curve == nil { if curve == nil {
return nil, errors.InvalidArgumentError("unsupported curve") return nil, errors.InvalidArgumentError("unsupported curve")
@@ -263,6 +307,18 @@ func newSigner(config *packet.Config) (signer interface{}, err error) {
return nil, err return nil, err
} }
return priv, nil return priv, nil
case packet.PubKeyAlgoEd25519:
priv, err := ed25519.GenerateKey(config.Random())
if err != nil {
return nil, err
}
return priv, nil
case packet.PubKeyAlgoEd448:
priv, err := ed448.GenerateKey(config.Random())
if err != nil {
return nil, err
}
return priv, nil
default: default:
return nil, errors.InvalidArgumentError("unsupported public key algorithm") return nil, errors.InvalidArgumentError("unsupported public key algorithm")
} }
@@ -285,6 +341,13 @@ func newDecrypter(config *packet.Config) (decrypter interface{}, err error) {
case packet.PubKeyAlgoEdDSA, packet.PubKeyAlgoECDSA: case packet.PubKeyAlgoEdDSA, packet.PubKeyAlgoECDSA:
fallthrough // When passing EdDSA or ECDSA, we generate an ECDH subkey fallthrough // When passing EdDSA or ECDSA, we generate an ECDH subkey
case packet.PubKeyAlgoECDH: case packet.PubKeyAlgoECDH:
if config.V6() &&
(config.CurveName() == packet.Curve25519 ||
config.CurveName() == packet.Curve448) {
// Implementations MUST NOT accept or generate v6 key material
// using the deprecated OIDs.
return nil, errors.InvalidArgumentError("ECDH with Curve25519/448 legacy cannot be used for v6 keys")
}
var kdf = ecdh.KDF{ var kdf = ecdh.KDF{
Hash: algorithm.SHA512, Hash: algorithm.SHA512,
Cipher: algorithm.AES256, Cipher: algorithm.AES256,
@@ -294,6 +357,10 @@ func newDecrypter(config *packet.Config) (decrypter interface{}, err error) {
return nil, errors.InvalidArgumentError("unsupported curve") return nil, errors.InvalidArgumentError("unsupported curve")
} }
return ecdh.GenerateKey(config.Random(), curve, kdf) return ecdh.GenerateKey(config.Random(), curve, kdf)
case packet.PubKeyAlgoEd25519, packet.PubKeyAlgoX25519: // When passing Ed25519, we generate an x25519 subkey
return x25519.GenerateKey(config.Random())
case packet.PubKeyAlgoEd448, packet.PubKeyAlgoX448: // When passing Ed448, we generate an x448 subkey
return x448.GenerateKey(config.Random())
default: default:
return nil, errors.InvalidArgumentError("unsupported public key algorithm") return nil, errors.InvalidArgumentError("unsupported public key algorithm")
} }
@@ -302,7 +369,7 @@ func newDecrypter(config *packet.Config) (decrypter interface{}, err error) {
var bigOne = big.NewInt(1) var bigOne = big.NewInt(1)
// generateRSAKeyWithPrimes generates a multi-prime RSA keypair of the // generateRSAKeyWithPrimes generates a multi-prime RSA keypair of the
// given bit size, using the given random source and prepopulated primes. // given bit size, using the given random source and pre-populated primes.
func generateRSAKeyWithPrimes(random io.Reader, nprimes int, bits int, prepopulatedPrimes []*big.Int) (*rsa.PrivateKey, error) { func generateRSAKeyWithPrimes(random io.Reader, nprimes int, bits int, prepopulatedPrimes []*big.Int) (*rsa.PrivateKey, error) {
priv := new(rsa.PrivateKey) priv := new(rsa.PrivateKey)
priv.E = 65537 priv.E = 65537

View File

@@ -6,6 +6,7 @@ package openpgp
import ( import (
goerrors "errors" goerrors "errors"
"fmt"
"io" "io"
"time" "time"
@@ -24,11 +25,13 @@ var PrivateKeyType = "PGP PRIVATE KEY BLOCK"
// (which must be a signing key), one or more identities claimed by that key, // (which must be a signing key), one or more identities claimed by that key,
// and zero or more subkeys, which may be encryption keys. // and zero or more subkeys, which may be encryption keys.
type Entity struct { type Entity struct {
PrimaryKey *packet.PublicKey PrimaryKey *packet.PublicKey
PrivateKey *packet.PrivateKey PrivateKey *packet.PrivateKey
Identities map[string]*Identity // indexed by Identity.Name Identities map[string]*Identity // indexed by Identity.Name
Revocations []*packet.Signature Revocations []*packet.Signature
Subkeys []Subkey Subkeys []Subkey
SelfSignature *packet.Signature // Direct-key self signature of the PrimaryKey (contains primary key properties in v6)
Signatures []*packet.Signature // all (potentially unverified) self-signatures, revocations, and third-party signatures
} }
// An Identity represents an identity claimed by an Entity and zero or more // An Identity represents an identity claimed by an Entity and zero or more
@@ -120,12 +123,12 @@ func shouldPreferIdentity(existingId, potentialNewId *Identity) bool {
// given Entity. // given Entity.
func (e *Entity) EncryptionKey(now time.Time) (Key, bool) { func (e *Entity) EncryptionKey(now time.Time) (Key, bool) {
// Fail to find any encryption key if the... // Fail to find any encryption key if the...
i := e.PrimaryIdentity() primarySelfSignature, primaryIdentity := e.PrimarySelfSignature()
if e.PrimaryKey.KeyExpired(i.SelfSignature, now) || // primary key has expired if primarySelfSignature == nil || // no self-signature found
i.SelfSignature == nil || // user ID has no self-signature e.PrimaryKey.KeyExpired(primarySelfSignature, now) || // primary key has expired
i.SelfSignature.SigExpired(now) || // user ID self-signature has expired
e.Revoked(now) || // primary key has been revoked e.Revoked(now) || // primary key has been revoked
i.Revoked(now) { // user ID has been revoked primarySelfSignature.SigExpired(now) || // user ID or or direct self-signature has expired
(primaryIdentity != nil && primaryIdentity.Revoked(now)) { // user ID has been revoked (for v4 keys)
return Key{}, false return Key{}, false
} }
@@ -152,9 +155,9 @@ func (e *Entity) EncryptionKey(now time.Time) (Key, bool) {
// If we don't have any subkeys for encryption and the primary key // If we don't have any subkeys for encryption and the primary key
// is marked as OK to encrypt with, then we can use it. // is marked as OK to encrypt with, then we can use it.
if i.SelfSignature.FlagsValid && i.SelfSignature.FlagEncryptCommunications && if primarySelfSignature.FlagsValid && primarySelfSignature.FlagEncryptCommunications &&
e.PrimaryKey.PubKeyAlgo.CanEncrypt() { e.PrimaryKey.PubKeyAlgo.CanEncrypt() {
return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature, e.Revocations}, true return Key{e, e.PrimaryKey, e.PrivateKey, primarySelfSignature, e.Revocations}, true
} }
return Key{}, false return Key{}, false
@@ -186,12 +189,12 @@ func (e *Entity) SigningKeyById(now time.Time, id uint64) (Key, bool) {
func (e *Entity) signingKeyByIdUsage(now time.Time, id uint64, flags int) (Key, bool) { func (e *Entity) signingKeyByIdUsage(now time.Time, id uint64, flags int) (Key, bool) {
// Fail to find any signing key if the... // Fail to find any signing key if the...
i := e.PrimaryIdentity() primarySelfSignature, primaryIdentity := e.PrimarySelfSignature()
if e.PrimaryKey.KeyExpired(i.SelfSignature, now) || // primary key has expired if primarySelfSignature == nil || // no self-signature found
i.SelfSignature == nil || // user ID has no self-signature e.PrimaryKey.KeyExpired(primarySelfSignature, now) || // primary key has expired
i.SelfSignature.SigExpired(now) || // user ID self-signature has expired
e.Revoked(now) || // primary key has been revoked e.Revoked(now) || // primary key has been revoked
i.Revoked(now) { // user ID has been revoked primarySelfSignature.SigExpired(now) || // user ID or direct self-signature has expired
(primaryIdentity != nil && primaryIdentity.Revoked(now)) { // user ID has been revoked (for v4 keys)
return Key{}, false return Key{}, false
} }
@@ -220,12 +223,12 @@ func (e *Entity) signingKeyByIdUsage(now time.Time, id uint64, flags int) (Key,
// If we don't have any subkeys for signing and the primary key // If we don't have any subkeys for signing and the primary key
// is marked as OK to sign with, then we can use it. // is marked as OK to sign with, then we can use it.
if i.SelfSignature.FlagsValid && if primarySelfSignature.FlagsValid &&
(flags&packet.KeyFlagCertify == 0 || i.SelfSignature.FlagCertify) && (flags&packet.KeyFlagCertify == 0 || primarySelfSignature.FlagCertify) &&
(flags&packet.KeyFlagSign == 0 || i.SelfSignature.FlagSign) && (flags&packet.KeyFlagSign == 0 || primarySelfSignature.FlagSign) &&
e.PrimaryKey.PubKeyAlgo.CanSign() && e.PrimaryKey.PubKeyAlgo.CanSign() &&
(id == 0 || e.PrimaryKey.KeyId == id) { (id == 0 || e.PrimaryKey.KeyId == id) {
return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature, e.Revocations}, true return Key{e, e.PrimaryKey, e.PrivateKey, primarySelfSignature, e.Revocations}, true
} }
// No keys with a valid Signing Flag or no keys matched the id passed in // No keys with a valid Signing Flag or no keys matched the id passed in
@@ -259,7 +262,7 @@ func (e *Entity) EncryptPrivateKeys(passphrase []byte, config *packet.Config) er
var keysToEncrypt []*packet.PrivateKey var keysToEncrypt []*packet.PrivateKey
// Add entity private key to encrypt. // Add entity private key to encrypt.
if e.PrivateKey != nil && !e.PrivateKey.Dummy() && !e.PrivateKey.Encrypted { if e.PrivateKey != nil && !e.PrivateKey.Dummy() && !e.PrivateKey.Encrypted {
keysToEncrypt = append(keysToEncrypt, e.PrivateKey) keysToEncrypt = append(keysToEncrypt, e.PrivateKey)
} }
// Add subkeys to encrypt. // Add subkeys to encrypt.
@@ -271,7 +274,7 @@ func (e *Entity) EncryptPrivateKeys(passphrase []byte, config *packet.Config) er
return packet.EncryptPrivateKeys(keysToEncrypt, passphrase, config) return packet.EncryptPrivateKeys(keysToEncrypt, passphrase, config)
} }
// DecryptPrivateKeys decrypts all encrypted keys in the entitiy with the given passphrase. // DecryptPrivateKeys decrypts all encrypted keys in the entity with the given passphrase.
// Avoids recomputation of similar s2k key derivations. Public keys and dummy keys are ignored, // Avoids recomputation of similar s2k key derivations. Public keys and dummy keys are ignored,
// and don't cause an error to be returned. // and don't cause an error to be returned.
func (e *Entity) DecryptPrivateKeys(passphrase []byte) error { func (e *Entity) DecryptPrivateKeys(passphrase []byte) error {
@@ -284,7 +287,7 @@ func (e *Entity) DecryptPrivateKeys(passphrase []byte) error {
// Add subkeys to decrypt. // Add subkeys to decrypt.
for _, sub := range e.Subkeys { for _, sub := range e.Subkeys {
if sub.PrivateKey != nil && !sub.PrivateKey.Dummy() && sub.PrivateKey.Encrypted { if sub.PrivateKey != nil && !sub.PrivateKey.Dummy() && sub.PrivateKey.Encrypted {
keysToDecrypt = append(keysToDecrypt, sub.PrivateKey) keysToDecrypt = append(keysToDecrypt, sub.PrivateKey)
} }
} }
return packet.DecryptPrivateKeys(keysToDecrypt, passphrase) return packet.DecryptPrivateKeys(keysToDecrypt, passphrase)
@@ -318,8 +321,7 @@ type EntityList []*Entity
func (el EntityList) KeysById(id uint64) (keys []Key) { func (el EntityList) KeysById(id uint64) (keys []Key) {
for _, e := range el { for _, e := range el {
if e.PrimaryKey.KeyId == id { if e.PrimaryKey.KeyId == id {
ident := e.PrimaryIdentity() selfSig, _ := e.PrimarySelfSignature()
selfSig := ident.SelfSignature
keys = append(keys, Key{e, e.PrimaryKey, e.PrivateKey, selfSig, e.Revocations}) keys = append(keys, Key{e, e.PrimaryKey, e.PrivateKey, selfSig, e.Revocations})
} }
@@ -441,7 +443,6 @@ func readToNextPublicKey(packets *packet.Reader) (err error) {
return return
} else if err != nil { } else if err != nil {
if _, ok := err.(errors.UnsupportedError); ok { if _, ok := err.(errors.UnsupportedError); ok {
err = nil
continue continue
} }
return return
@@ -479,6 +480,7 @@ func ReadEntity(packets *packet.Reader) (*Entity, error) {
} }
var revocations []*packet.Signature var revocations []*packet.Signature
var directSignatures []*packet.Signature
EachPacket: EachPacket:
for { for {
p, err := packets.Next() p, err := packets.Next()
@@ -497,9 +499,7 @@ EachPacket:
if pkt.SigType == packet.SigTypeKeyRevocation { if pkt.SigType == packet.SigTypeKeyRevocation {
revocations = append(revocations, pkt) revocations = append(revocations, pkt)
} else if pkt.SigType == packet.SigTypeDirectSignature { } else if pkt.SigType == packet.SigTypeDirectSignature {
// TODO: RFC4880 5.2.1 permits signatures directSignatures = append(directSignatures, pkt)
// directly on keys (eg. to bind additional
// revocation keys).
} }
// Else, ignoring the signature as it does not follow anything // Else, ignoring the signature as it does not follow anything
// we would know to attach it to. // we would know to attach it to.
@@ -522,12 +522,39 @@ EachPacket:
return nil, err return nil, err
} }
default: default:
// we ignore unknown packets // we ignore unknown packets.
} }
} }
if len(e.Identities) == 0 { if len(e.Identities) == 0 && e.PrimaryKey.Version < 6 {
return nil, errors.StructuralError("entity without any identities") return nil, errors.StructuralError(fmt.Sprintf("v%d entity without any identities", e.PrimaryKey.Version))
}
// An implementation MUST ensure that a valid direct-key signature is present before using a v6 key.
if e.PrimaryKey.Version == 6 {
if len(directSignatures) == 0 {
return nil, errors.StructuralError("v6 entity without a valid direct-key signature")
}
// Select main direct key signature.
var mainDirectKeySelfSignature *packet.Signature
for _, directSignature := range directSignatures {
if directSignature.SigType == packet.SigTypeDirectSignature &&
directSignature.CheckKeyIdOrFingerprint(e.PrimaryKey) &&
(mainDirectKeySelfSignature == nil ||
directSignature.CreationTime.After(mainDirectKeySelfSignature.CreationTime)) {
mainDirectKeySelfSignature = directSignature
}
}
if mainDirectKeySelfSignature == nil {
return nil, errors.StructuralError("no valid direct-key self-signature for v6 primary key found")
}
// Check that the main self-signature is valid.
err = e.PrimaryKey.VerifyDirectKeySignature(mainDirectKeySelfSignature)
if err != nil {
return nil, errors.StructuralError("invalid direct-key self-signature for v6 primary key")
}
e.SelfSignature = mainDirectKeySelfSignature
e.Signatures = directSignatures
} }
for _, revocation := range revocations { for _, revocation := range revocations {
@@ -672,6 +699,12 @@ func (e *Entity) serializePrivate(w io.Writer, config *packet.Config, reSign boo
return err return err
} }
} }
for _, directSignature := range e.Signatures {
err := directSignature.Serialize(w)
if err != nil {
return err
}
}
for _, ident := range e.Identities { for _, ident := range e.Identities {
err = ident.UserId.Serialize(w) err = ident.UserId.Serialize(w)
if err != nil { if err != nil {
@@ -738,6 +771,12 @@ func (e *Entity) Serialize(w io.Writer) error {
return err return err
} }
} }
for _, directSignature := range e.Signatures {
err := directSignature.Serialize(w)
if err != nil {
return err
}
}
for _, ident := range e.Identities { for _, ident := range e.Identities {
err = ident.UserId.Serialize(w) err = ident.UserId.Serialize(w)
if err != nil { if err != nil {
@@ -840,3 +879,23 @@ func (e *Entity) RevokeSubkey(sk *Subkey, reason packet.ReasonForRevocation, rea
sk.Revocations = append(sk.Revocations, revSig) sk.Revocations = append(sk.Revocations, revSig)
return nil return nil
} }
func (e *Entity) primaryDirectSignature() *packet.Signature {
return e.SelfSignature
}
// PrimarySelfSignature searches the entity for the self-signature that stores key preferences.
// For V4 keys, returns the self-signature of the primary identity, and the identity.
// For V6 keys, returns the latest valid direct-key self-signature, and no identity (nil).
// This self-signature is to be used to check the key expiration,
// algorithm preferences, and so on.
func (e *Entity) PrimarySelfSignature() (*packet.Signature, *Identity) {
if e.PrimaryKey.Version == 6 {
return e.primaryDirectSignature(), nil
}
primaryIdentity := e.PrimaryIdentity()
if primaryIdentity == nil {
return nil, nil
}
return primaryIdentity.SelfSignature, primaryIdentity
}

View File

@@ -37,7 +37,7 @@ func (conf *AEADConfig) Mode() AEADMode {
// ChunkSizeByte returns the byte indicating the chunk size. The effective // ChunkSizeByte returns the byte indicating the chunk size. The effective
// chunk size is computed with the formula uint64(1) << (chunkSizeByte + 6) // chunk size is computed with the formula uint64(1) << (chunkSizeByte + 6)
// limit to 16 = 4 MiB // limit chunkSizeByte to 16 which equals to 2^22 = 4 MiB
// https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-07.html#section-5.13.2 // https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-07.html#section-5.13.2
func (conf *AEADConfig) ChunkSizeByte() byte { func (conf *AEADConfig) ChunkSizeByte() byte {
if conf == nil || conf.ChunkSize == 0 { if conf == nil || conf.ChunkSize == 0 {
@@ -49,8 +49,8 @@ func (conf *AEADConfig) ChunkSizeByte() byte {
switch { switch {
case exponent < 6: case exponent < 6:
exponent = 6 exponent = 6
case exponent > 16: case exponent > 22:
exponent = 16 exponent = 22
} }
return byte(exponent - 6) return byte(exponent - 6)

View File

@@ -3,7 +3,6 @@
package packet package packet
import ( import (
"bytes"
"crypto/cipher" "crypto/cipher"
"encoding/binary" "encoding/binary"
"io" "io"
@@ -15,12 +14,11 @@ import (
type aeadCrypter struct { type aeadCrypter struct {
aead cipher.AEAD aead cipher.AEAD
chunkSize int chunkSize int
initialNonce []byte nonce []byte
associatedData []byte // Chunk-independent associated data associatedData []byte // Chunk-independent associated data
chunkIndex []byte // Chunk counter chunkIndex []byte // Chunk counter
packetTag packetType // SEIP packet (v2) or AEAD Encrypted Data packet packetTag packetType // SEIP packet (v2) or AEAD Encrypted Data packet
bytesProcessed int // Amount of plaintext bytes encrypted/decrypted bytesProcessed int // Amount of plaintext bytes encrypted/decrypted
buffer bytes.Buffer // Buffered bytes across chunks
} }
// computeNonce takes the incremental index and computes an eXclusive OR with // computeNonce takes the incremental index and computes an eXclusive OR with
@@ -28,12 +26,12 @@ type aeadCrypter struct {
// 5.16.1 and 5.16.2). It returns the resulting nonce. // 5.16.1 and 5.16.2). It returns the resulting nonce.
func (wo *aeadCrypter) computeNextNonce() (nonce []byte) { func (wo *aeadCrypter) computeNextNonce() (nonce []byte) {
if wo.packetTag == packetTypeSymmetricallyEncryptedIntegrityProtected { if wo.packetTag == packetTypeSymmetricallyEncryptedIntegrityProtected {
return append(wo.initialNonce, wo.chunkIndex...) return wo.nonce
} }
nonce = make([]byte, len(wo.initialNonce)) nonce = make([]byte, len(wo.nonce))
copy(nonce, wo.initialNonce) copy(nonce, wo.nonce)
offset := len(wo.initialNonce) - 8 offset := len(wo.nonce) - 8
for i := 0; i < 8; i++ { for i := 0; i < 8; i++ {
nonce[i+offset] ^= wo.chunkIndex[i] nonce[i+offset] ^= wo.chunkIndex[i]
} }
@@ -62,8 +60,9 @@ func (wo *aeadCrypter) incrementIndex() error {
type aeadDecrypter struct { type aeadDecrypter struct {
aeadCrypter // Embedded ciphertext opener aeadCrypter // Embedded ciphertext opener
reader io.Reader // 'reader' is a partialLengthReader reader io.Reader // 'reader' is a partialLengthReader
chunkBytes []byte
peekedBytes []byte // Used to detect last chunk peekedBytes []byte // Used to detect last chunk
eof bool buffer []byte // Buffered decrypted bytes
} }
// Read decrypts bytes and reads them into dst. It decrypts when necessary and // Read decrypts bytes and reads them into dst. It decrypts when necessary and
@@ -71,51 +70,45 @@ type aeadDecrypter struct {
// and an error. // and an error.
func (ar *aeadDecrypter) Read(dst []byte) (n int, err error) { func (ar *aeadDecrypter) Read(dst []byte) (n int, err error) {
// Return buffered plaintext bytes from previous calls // Return buffered plaintext bytes from previous calls
if ar.buffer.Len() > 0 { if len(ar.buffer) > 0 {
return ar.buffer.Read(dst) n = copy(dst, ar.buffer)
} ar.buffer = ar.buffer[n:]
return
// Return EOF if we've previously validated the final tag
if ar.eof {
return 0, io.EOF
} }
// Read a chunk // Read a chunk
tagLen := ar.aead.Overhead() tagLen := ar.aead.Overhead()
cipherChunkBuf := new(bytes.Buffer) copy(ar.chunkBytes, ar.peekedBytes) // Copy bytes peeked in previous chunk or in initialization
_, errRead := io.CopyN(cipherChunkBuf, ar.reader, int64(ar.chunkSize+tagLen)) bytesRead, errRead := io.ReadFull(ar.reader, ar.chunkBytes[tagLen:])
cipherChunk := cipherChunkBuf.Bytes() if errRead != nil && errRead != io.EOF && errRead != io.ErrUnexpectedEOF {
if errRead != nil && errRead != io.EOF {
return 0, errRead return 0, errRead
} }
decrypted, errChunk := ar.openChunk(cipherChunk)
if errChunk != nil {
return 0, errChunk
}
// Return decrypted bytes, buffering if necessary if bytesRead > 0 {
if len(dst) < len(decrypted) { ar.peekedBytes = ar.chunkBytes[bytesRead:bytesRead+tagLen]
n = copy(dst, decrypted[:len(dst)])
ar.buffer.Write(decrypted[len(dst):])
} else {
n = copy(dst, decrypted)
}
// Check final authentication tag decrypted, errChunk := ar.openChunk(ar.chunkBytes[:bytesRead])
if errRead == io.EOF {
errChunk := ar.validateFinalTag(ar.peekedBytes)
if errChunk != nil { if errChunk != nil {
return n, errChunk return 0, errChunk
} }
ar.eof = true // Mark EOF for when we've returned all buffered data
// Return decrypted bytes, buffering if necessary
n = copy(dst, decrypted)
ar.buffer = decrypted[n:]
return
} }
return
return 0, io.EOF
} }
// Close is noOp. The final authentication tag of the stream was already // Close checks the final authentication tag of the stream.
// checked in the last Read call. In the future, this function could be used to // In the future, this function could also be used to wipe the reader
// wipe the reader and peeked, decrypted bytes, if necessary. // and peeked & decrypted bytes, if necessary.
func (ar *aeadDecrypter) Close() (err error) { func (ar *aeadDecrypter) Close() (err error) {
errChunk := ar.validateFinalTag(ar.peekedBytes)
if errChunk != nil {
return errChunk
}
return nil return nil
} }
@@ -123,22 +116,15 @@ func (ar *aeadDecrypter) Close() (err error) {
// the underlying plaintext and an error. It accesses peeked bytes from next // the underlying plaintext and an error. It accesses peeked bytes from next
// chunk, to identify the last chunk and decrypt/validate accordingly. // chunk, to identify the last chunk and decrypt/validate accordingly.
func (ar *aeadDecrypter) openChunk(data []byte) ([]byte, error) { func (ar *aeadDecrypter) openChunk(data []byte) ([]byte, error) {
tagLen := ar.aead.Overhead()
// Restore carried bytes from last call
chunkExtra := append(ar.peekedBytes, data...)
// 'chunk' contains encrypted bytes, followed by an authentication tag.
chunk := chunkExtra[:len(chunkExtra)-tagLen]
ar.peekedBytes = chunkExtra[len(chunkExtra)-tagLen:]
adata := ar.associatedData adata := ar.associatedData
if ar.aeadCrypter.packetTag == packetTypeAEADEncrypted { if ar.aeadCrypter.packetTag == packetTypeAEADEncrypted {
adata = append(ar.associatedData, ar.chunkIndex...) adata = append(ar.associatedData, ar.chunkIndex...)
} }
nonce := ar.computeNextNonce() nonce := ar.computeNextNonce()
plainChunk, err := ar.aead.Open(nil, nonce, chunk, adata) plainChunk, err := ar.aead.Open(data[:0:len(data)], nonce, data, adata)
if err != nil { if err != nil {
return nil, err return nil, errors.ErrAEADTagVerification
} }
ar.bytesProcessed += len(plainChunk) ar.bytesProcessed += len(plainChunk)
if err = ar.aeadCrypter.incrementIndex(); err != nil { if err = ar.aeadCrypter.incrementIndex(); err != nil {
@@ -163,9 +149,8 @@ func (ar *aeadDecrypter) validateFinalTag(tag []byte) error {
// ... and total number of encrypted octets // ... and total number of encrypted octets
adata = append(adata, amountBytes...) adata = append(adata, amountBytes...)
nonce := ar.computeNextNonce() nonce := ar.computeNextNonce()
_, err := ar.aead.Open(nil, nonce, tag, adata) if _, err := ar.aead.Open(nil, nonce, tag, adata); err != nil {
if err != nil { return errors.ErrAEADTagVerification
return err
} }
return nil return nil
} }
@@ -175,27 +160,29 @@ func (ar *aeadDecrypter) validateFinalTag(tag []byte) error {
type aeadEncrypter struct { type aeadEncrypter struct {
aeadCrypter // Embedded plaintext sealer aeadCrypter // Embedded plaintext sealer
writer io.WriteCloser // 'writer' is a partialLengthWriter writer io.WriteCloser // 'writer' is a partialLengthWriter
chunkBytes []byte
offset int
} }
// Write encrypts and writes bytes. It encrypts when necessary and buffers extra // Write encrypts and writes bytes. It encrypts when necessary and buffers extra
// plaintext bytes for next call. When the stream is finished, Close() MUST be // plaintext bytes for next call. When the stream is finished, Close() MUST be
// called to append the final tag. // called to append the final tag.
func (aw *aeadEncrypter) Write(plaintextBytes []byte) (n int, err error) { func (aw *aeadEncrypter) Write(plaintextBytes []byte) (n int, err error) {
// Append plaintextBytes to existing buffered bytes for n != len(plaintextBytes) {
n, err = aw.buffer.Write(plaintextBytes) copied := copy(aw.chunkBytes[aw.offset:aw.chunkSize], plaintextBytes[n:])
if err != nil { n += copied
return n, err aw.offset += copied
}
// Encrypt and write chunks if aw.offset == aw.chunkSize {
for aw.buffer.Len() >= aw.chunkSize { encryptedChunk, err := aw.sealChunk(aw.chunkBytes[:aw.offset])
plainChunk := aw.buffer.Next(aw.chunkSize) if err != nil {
encryptedChunk, err := aw.sealChunk(plainChunk) return n, err
if err != nil { }
return n, err _, err = aw.writer.Write(encryptedChunk)
} if err != nil {
_, err = aw.writer.Write(encryptedChunk) return n, err
if err != nil { }
return n, err aw.offset = 0
} }
} }
return return
@@ -207,9 +194,8 @@ func (aw *aeadEncrypter) Write(plaintextBytes []byte) (n int, err error) {
func (aw *aeadEncrypter) Close() (err error) { func (aw *aeadEncrypter) Close() (err error) {
// Encrypt and write a chunk if there's buffered data left, or if we haven't // Encrypt and write a chunk if there's buffered data left, or if we haven't
// written any chunks yet. // written any chunks yet.
if aw.buffer.Len() > 0 || aw.bytesProcessed == 0 { if aw.offset > 0 || aw.bytesProcessed == 0 {
plainChunk := aw.buffer.Bytes() lastEncryptedChunk, err := aw.sealChunk(aw.chunkBytes[:aw.offset])
lastEncryptedChunk, err := aw.sealChunk(plainChunk)
if err != nil { if err != nil {
return err return err
} }
@@ -255,7 +241,7 @@ func (aw *aeadEncrypter) sealChunk(data []byte) ([]byte, error) {
} }
nonce := aw.computeNextNonce() nonce := aw.computeNextNonce()
encrypted := aw.aead.Seal(nil, nonce, data, adata) encrypted := aw.aead.Seal(data[:0], nonce, data, adata)
aw.bytesProcessed += len(data) aw.bytesProcessed += len(data)
if err := aw.aeadCrypter.incrementIndex(); err != nil { if err := aw.aeadCrypter.incrementIndex(); err != nil {
return nil, err return nil, err

View File

@@ -65,24 +65,28 @@ func (ae *AEADEncrypted) decrypt(key []byte) (io.ReadCloser, error) {
blockCipher := ae.cipher.new(key) blockCipher := ae.cipher.new(key)
aead := ae.mode.new(blockCipher) aead := ae.mode.new(blockCipher)
// Carry the first tagLen bytes // Carry the first tagLen bytes
chunkSize := decodeAEADChunkSize(ae.chunkSizeByte)
tagLen := ae.mode.TagLength() tagLen := ae.mode.TagLength()
peekedBytes := make([]byte, tagLen) chunkBytes := make([]byte, chunkSize+tagLen*2)
peekedBytes := chunkBytes[chunkSize+tagLen:]
n, err := io.ReadFull(ae.Contents, peekedBytes) n, err := io.ReadFull(ae.Contents, peekedBytes)
if n < tagLen || (err != nil && err != io.EOF) { if n < tagLen || (err != nil && err != io.EOF) {
return nil, errors.AEADError("Not enough data to decrypt:" + err.Error()) return nil, errors.AEADError("Not enough data to decrypt:" + err.Error())
} }
chunkSize := decodeAEADChunkSize(ae.chunkSizeByte)
return &aeadDecrypter{ return &aeadDecrypter{
aeadCrypter: aeadCrypter{ aeadCrypter: aeadCrypter{
aead: aead, aead: aead,
chunkSize: chunkSize, chunkSize: chunkSize,
initialNonce: ae.initialNonce, nonce: ae.initialNonce,
associatedData: ae.associatedData(), associatedData: ae.associatedData(),
chunkIndex: make([]byte, 8), chunkIndex: make([]byte, 8),
packetTag: packetTypeAEADEncrypted, packetTag: packetTypeAEADEncrypted,
}, },
reader: ae.Contents, reader: ae.Contents,
peekedBytes: peekedBytes}, nil chunkBytes: chunkBytes,
peekedBytes: peekedBytes,
}, nil
} }
// associatedData for chunks: tag, version, cipher, mode, chunk size byte // associatedData for chunks: tag, version, cipher, mode, chunk size byte

View File

@@ -8,9 +8,10 @@ import (
"compress/bzip2" "compress/bzip2"
"compress/flate" "compress/flate"
"compress/zlib" "compress/zlib"
"github.com/ProtonMail/go-crypto/openpgp/errors"
"io" "io"
"strconv" "strconv"
"github.com/ProtonMail/go-crypto/openpgp/errors"
) )
// Compressed represents a compressed OpenPGP packet. The decompressed contents // Compressed represents a compressed OpenPGP packet. The decompressed contents
@@ -39,6 +40,37 @@ type CompressionConfig struct {
Level int Level int
} }
// decompressionReader ensures that the whole compression packet is read.
type decompressionReader struct {
compressed io.Reader
decompressed io.ReadCloser
readAll bool
}
func newDecompressionReader(r io.Reader, decompressor io.ReadCloser) *decompressionReader {
return &decompressionReader{
compressed: r,
decompressed: decompressor,
}
}
func (dr *decompressionReader) Read(data []byte) (n int, err error) {
if dr.readAll {
return 0, io.EOF
}
n, err = dr.decompressed.Read(data)
if err == io.EOF {
dr.readAll = true
// Close the decompressor.
if errDec := dr.decompressed.Close(); errDec != nil {
return n, errDec
}
// Consume all remaining data from the compressed packet.
consumeAll(dr.compressed)
}
return n, err
}
func (c *Compressed) parse(r io.Reader) error { func (c *Compressed) parse(r io.Reader) error {
var buf [1]byte var buf [1]byte
_, err := readFull(r, buf[:]) _, err := readFull(r, buf[:])
@@ -50,11 +82,15 @@ func (c *Compressed) parse(r io.Reader) error {
case 0: case 0:
c.Body = r c.Body = r
case 1: case 1:
c.Body = flate.NewReader(r) c.Body = newDecompressionReader(r, flate.NewReader(r))
case 2: case 2:
c.Body, err = zlib.NewReader(r) decompressor, err := zlib.NewReader(r)
if err != nil {
return err
}
c.Body = newDecompressionReader(r, decompressor)
case 3: case 3:
c.Body = bzip2.NewReader(r) c.Body = newDecompressionReader(r, io.NopCloser(bzip2.NewReader(r)))
default: default:
err = errors.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0]))) err = errors.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0])))
} }
@@ -62,6 +98,16 @@ func (c *Compressed) parse(r io.Reader) error {
return err return err
} }
// LimitedBodyReader wraps the provided body reader with a limiter that restricts
// the number of bytes read to the specified limit.
// If limit is nil, the reader is unbounded.
func (c *Compressed) LimitedBodyReader(limit *int64) io.Reader {
if limit == nil {
return c.Body
}
return &LimitReader{R: c.Body, N: *limit}
}
// compressedWriterCloser represents the serialized compression stream // compressedWriterCloser represents the serialized compression stream
// header and the compressor. Its Close() method ensures that both the // header and the compressor. Its Close() method ensures that both the
// compressor and serialized stream header are closed. Its Write() // compressor and serialized stream header are closed. Its Write()
@@ -123,3 +169,24 @@ func SerializeCompressed(w io.WriteCloser, algo CompressionAlgo, cc *Compression
return return
} }
// LimitReader is an io.Reader that fails with MessageToLarge if read bytes exceed N.
type LimitReader struct {
R io.Reader // underlying reader
N int64 // max bytes allowed
}
func (l *LimitReader) Read(p []byte) (int, error) {
if l.N <= 0 {
return 0, errors.ErrMessageTooLarge
}
n, err := l.R.Read(p)
l.N -= int64(n)
if err == nil && l.N <= 0 {
err = errors.ErrMessageTooLarge
}
return n, err
}

View File

@@ -14,6 +14,34 @@ import (
"github.com/ProtonMail/go-crypto/openpgp/s2k" "github.com/ProtonMail/go-crypto/openpgp/s2k"
) )
var (
defaultRejectPublicKeyAlgorithms = map[PublicKeyAlgorithm]bool{
PubKeyAlgoElGamal: true,
PubKeyAlgoDSA: true,
}
defaultRejectHashAlgorithms = map[crypto.Hash]bool{
crypto.MD5: true,
crypto.RIPEMD160: true,
}
defaultRejectMessageHashAlgorithms = map[crypto.Hash]bool{
crypto.SHA1: true,
crypto.MD5: true,
crypto.RIPEMD160: true,
}
defaultRejectCurves = map[Curve]bool{
CurveSecP256k1: true,
}
)
// A global feature flag to indicate v5 support.
// Can be set via a build tag, e.g.: `go build -tags v5 ./...`
// If the build tag is missing config_v5.go will set it to true.
//
// Disables parsing of v5 keys and v5 signatures.
// These are non-standard entities, which in the crypto-refresh have been superseded
// by v6 keys, v6 signatures and SEIPDv2 encrypted data, respectively.
var V5Disabled = false
// Config collects a number of parameters along with sensible defaults. // Config collects a number of parameters along with sensible defaults.
// A nil *Config is valid and results in all default values. // A nil *Config is valid and results in all default values.
type Config struct { type Config struct {
@@ -73,9 +101,16 @@ type Config struct {
// **Note: using this option may break compatibility with other OpenPGP // **Note: using this option may break compatibility with other OpenPGP
// implementations, as well as future versions of this library.** // implementations, as well as future versions of this library.**
AEADConfig *AEADConfig AEADConfig *AEADConfig
// V5Keys configures version 5 key generation. If false, this package still // V6Keys configures version 6 key generation. If false, this package still
// supports version 5 keys, but produces version 4 keys. // supports version 6 keys, but produces version 4 keys.
V5Keys bool V6Keys bool
// Minimum RSA key size allowed for key generation and message signing, verification and encryption.
MinRSABits uint16
// Reject insecure algorithms, only works with v2 api
RejectPublicKeyAlgorithms map[PublicKeyAlgorithm]bool
RejectHashAlgorithms map[crypto.Hash]bool
RejectMessageHashAlgorithms map[crypto.Hash]bool
RejectCurves map[Curve]bool
// "The validity period of the key. This is the number of seconds after // "The validity period of the key. This is the number of seconds after
// the key creation time that the key expires. If this is not present // the key creation time that the key expires. If this is not present
// or has a value of zero, the key never expires. This is found only on // or has a value of zero, the key never expires. This is found only on
@@ -104,12 +139,50 @@ type Config struct {
// might be no other way than to tolerate the missing MDC. Setting this flag, allows this // might be no other way than to tolerate the missing MDC. Setting this flag, allows this
// mode of operation. It should be considered a measure of last resort. // mode of operation. It should be considered a measure of last resort.
InsecureAllowUnauthenticatedMessages bool InsecureAllowUnauthenticatedMessages bool
// InsecureAllowDecryptionWithSigningKeys allows decryption with keys marked as signing keys in the v2 API.
// This setting is potentially insecure, but it is needed as some libraries
// ignored key flags when selecting a key for encryption.
// Not relevant for the v1 API, as all keys were allowed in decryption.
InsecureAllowDecryptionWithSigningKeys bool
// KnownNotations is a map of Notation Data names to bools, which controls // KnownNotations is a map of Notation Data names to bools, which controls
// the notation names that are allowed to be present in critical Notation Data // the notation names that are allowed to be present in critical Notation Data
// signature subpackets. // signature subpackets.
KnownNotations map[string]bool KnownNotations map[string]bool
// SignatureNotations is a list of Notations to be added to any signatures. // SignatureNotations is a list of Notations to be added to any signatures.
SignatureNotations []*Notation SignatureNotations []*Notation
// CheckIntendedRecipients controls, whether the OpenPGP Intended Recipient Fingerprint feature
// should be enabled for encryption and decryption.
// (See https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-12.html#name-intended-recipient-fingerpr).
// When the flag is set, encryption produces Intended Recipient Fingerprint signature sub-packets and decryption
// checks whether the key it was encrypted to is one of the included fingerprints in the signature.
// If the flag is disabled, no Intended Recipient Fingerprint sub-packets are created or checked.
// The default behavior, when the config or flag is nil, is to enable the feature.
CheckIntendedRecipients *bool
// CacheSessionKey controls if decryption should return the session key used for decryption.
// If the flag is set, the session key is cached in the message details struct.
CacheSessionKey bool
// CheckPacketSequence is a flag that controls if the pgp message reader should strictly check
// that the packet sequence conforms with the grammar mandated by rfc4880.
// The default behavior, when the config or flag is nil, is to check the packet sequence.
CheckPacketSequence *bool
// NonDeterministicSignaturesViaNotation is a flag to enable randomization of signatures.
// If true, a salt notation is used to randomize signatures generated by v4 and v5 keys
// (v6 signatures are always non-deterministic, by design).
// This protects EdDSA signatures from potentially leaking the secret key in case of faults (i.e. bitflips) which, in principle, could occur
// during the signing computation. It is added to signatures of any algo for simplicity, and as it may also serve as protection in case of
// weaknesses in the hash algo, potentially hindering e.g. some chosen-prefix attacks.
// The default behavior, when the config or flag is nil, is to enable the feature.
NonDeterministicSignaturesViaNotation *bool
// InsecureAllowAllKeyFlagsWhenMissing determines how a key without valid key flags is handled.
// When set to true, a key without flags is treated as if all flags are enabled.
// This behavior is consistent with GPG.
InsecureAllowAllKeyFlagsWhenMissing bool
// MaxDecompressedMessageSize specifies the maximum number of bytes that can be
// read from a compressed packet. This serves as an upper limit to prevent
// excessively large decompressed messages.
MaxDecompressedMessageSize *int64
} }
func (c *Config) Random() io.Reader { func (c *Config) Random() io.Reader {
@@ -197,7 +270,7 @@ func (c *Config) S2K() *s2k.Config {
return nil return nil
} }
// for backwards compatibility // for backwards compatibility
if c != nil && c.S2KCount > 0 && c.S2KConfig == nil { if c.S2KCount > 0 && c.S2KConfig == nil {
return &s2k.Config{ return &s2k.Config{
S2KCount: c.S2KCount, S2KCount: c.S2KCount,
} }
@@ -233,6 +306,13 @@ func (c *Config) AllowUnauthenticatedMessages() bool {
return c.InsecureAllowUnauthenticatedMessages return c.InsecureAllowUnauthenticatedMessages
} }
func (c *Config) AllowDecryptionWithSigningKeys() bool {
if c == nil {
return false
}
return c.InsecureAllowDecryptionWithSigningKeys
}
func (c *Config) KnownNotation(notationName string) bool { func (c *Config) KnownNotation(notationName string) bool {
if c == nil { if c == nil {
return false return false
@@ -246,3 +326,109 @@ func (c *Config) Notations() []*Notation {
} }
return c.SignatureNotations return c.SignatureNotations
} }
func (c *Config) V6() bool {
if c == nil {
return false
}
return c.V6Keys
}
func (c *Config) IntendedRecipients() bool {
if c == nil || c.CheckIntendedRecipients == nil {
return true
}
return *c.CheckIntendedRecipients
}
func (c *Config) RetrieveSessionKey() bool {
if c == nil {
return false
}
return c.CacheSessionKey
}
func (c *Config) MinimumRSABits() uint16 {
if c == nil || c.MinRSABits == 0 {
return 2047
}
return c.MinRSABits
}
func (c *Config) RejectPublicKeyAlgorithm(alg PublicKeyAlgorithm) bool {
var rejectedAlgorithms map[PublicKeyAlgorithm]bool
if c == nil || c.RejectPublicKeyAlgorithms == nil {
// Default
rejectedAlgorithms = defaultRejectPublicKeyAlgorithms
} else {
rejectedAlgorithms = c.RejectPublicKeyAlgorithms
}
return rejectedAlgorithms[alg]
}
func (c *Config) RejectHashAlgorithm(hash crypto.Hash) bool {
var rejectedAlgorithms map[crypto.Hash]bool
if c == nil || c.RejectHashAlgorithms == nil {
// Default
rejectedAlgorithms = defaultRejectHashAlgorithms
} else {
rejectedAlgorithms = c.RejectHashAlgorithms
}
return rejectedAlgorithms[hash]
}
func (c *Config) RejectMessageHashAlgorithm(hash crypto.Hash) bool {
var rejectedAlgorithms map[crypto.Hash]bool
if c == nil || c.RejectMessageHashAlgorithms == nil {
// Default
rejectedAlgorithms = defaultRejectMessageHashAlgorithms
} else {
rejectedAlgorithms = c.RejectMessageHashAlgorithms
}
return rejectedAlgorithms[hash]
}
func (c *Config) RejectCurve(curve Curve) bool {
var rejectedCurve map[Curve]bool
if c == nil || c.RejectCurves == nil {
// Default
rejectedCurve = defaultRejectCurves
} else {
rejectedCurve = c.RejectCurves
}
return rejectedCurve[curve]
}
func (c *Config) StrictPacketSequence() bool {
if c == nil || c.CheckPacketSequence == nil {
return true
}
return *c.CheckPacketSequence
}
func (c *Config) RandomizeSignaturesViaNotation() bool {
if c == nil || c.NonDeterministicSignaturesViaNotation == nil {
return true
}
return *c.NonDeterministicSignaturesViaNotation
}
func (c *Config) AllowAllKeyFlagsWhenMissing() bool {
if c == nil {
return false
}
return c.InsecureAllowAllKeyFlagsWhenMissing
}
func (c *Config) DecompressedMessageSizeLimit() *int64 {
if c == nil {
return nil
}
return c.MaxDecompressedMessageSize
}
// BoolPointer is a helper function to set a boolean pointer in the Config.
// e.g., config.CheckPacketSequence = BoolPointer(true)
func BoolPointer(value bool) *bool {
return &value
}

View File

@@ -0,0 +1,7 @@
//go:build !v5
package packet
func init() {
V5Disabled = true
}

View File

@@ -5,9 +5,11 @@
package packet package packet
import ( import (
"bytes"
"crypto" "crypto"
"crypto/rsa" "crypto/rsa"
"encoding/binary" "encoding/binary"
"encoding/hex"
"io" "io"
"math/big" "math/big"
"strconv" "strconv"
@@ -16,32 +18,85 @@ import (
"github.com/ProtonMail/go-crypto/openpgp/elgamal" "github.com/ProtonMail/go-crypto/openpgp/elgamal"
"github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/errors"
"github.com/ProtonMail/go-crypto/openpgp/internal/encoding" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding"
"github.com/ProtonMail/go-crypto/openpgp/x25519"
"github.com/ProtonMail/go-crypto/openpgp/x448"
) )
const encryptedKeyVersion = 3
// EncryptedKey represents a public-key encrypted session key. See RFC 4880, // EncryptedKey represents a public-key encrypted session key. See RFC 4880,
// section 5.1. // section 5.1.
type EncryptedKey struct { type EncryptedKey struct {
KeyId uint64 Version int
Algo PublicKeyAlgorithm KeyId uint64
CipherFunc CipherFunction // only valid after a successful Decrypt for a v3 packet KeyVersion int // v6
Key []byte // only valid after a successful Decrypt KeyFingerprint []byte // v6
Algo PublicKeyAlgorithm
CipherFunc CipherFunction // only valid after a successful Decrypt for a v3 packet
Key []byte // only valid after a successful Decrypt
encryptedMPI1, encryptedMPI2 encoding.Field encryptedMPI1, encryptedMPI2 encoding.Field
ephemeralPublicX25519 *x25519.PublicKey // used for x25519
ephemeralPublicX448 *x448.PublicKey // used for x448
encryptedSession []byte // used for x25519 and x448
} }
func (e *EncryptedKey) parse(r io.Reader) (err error) { func (e *EncryptedKey) parse(r io.Reader) (err error) {
var buf [10]byte var buf [8]byte
_, err = readFull(r, buf[:]) _, err = readFull(r, buf[:versionSize])
if err != nil { if err != nil {
return return
} }
if buf[0] != encryptedKeyVersion { e.Version = int(buf[0])
if e.Version != 3 && e.Version != 6 {
return errors.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0]))) return errors.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0])))
} }
e.KeyId = binary.BigEndian.Uint64(buf[1:9]) if e.Version == 6 {
e.Algo = PublicKeyAlgorithm(buf[9]) //Read a one-octet size of the following two fields.
if _, err = readFull(r, buf[:1]); err != nil {
return
}
// The size may also be zero, and the key version and
// fingerprint omitted for an "anonymous recipient"
if buf[0] != 0 {
// non-anonymous case
_, err = readFull(r, buf[:versionSize])
if err != nil {
return
}
e.KeyVersion = int(buf[0])
if e.KeyVersion != 4 && e.KeyVersion != 6 {
return errors.UnsupportedError("unknown public key version " + strconv.Itoa(e.KeyVersion))
}
var fingerprint []byte
if e.KeyVersion == 6 {
fingerprint = make([]byte, fingerprintSizeV6)
} else if e.KeyVersion == 4 {
fingerprint = make([]byte, fingerprintSize)
}
_, err = readFull(r, fingerprint)
if err != nil {
return
}
e.KeyFingerprint = fingerprint
if e.KeyVersion == 6 {
e.KeyId = binary.BigEndian.Uint64(e.KeyFingerprint[:keyIdSize])
} else if e.KeyVersion == 4 {
e.KeyId = binary.BigEndian.Uint64(e.KeyFingerprint[fingerprintSize-keyIdSize : fingerprintSize])
}
}
} else {
_, err = readFull(r, buf[:8])
if err != nil {
return
}
e.KeyId = binary.BigEndian.Uint64(buf[:keyIdSize])
}
_, err = readFull(r, buf[:1])
if err != nil {
return
}
e.Algo = PublicKeyAlgorithm(buf[0])
var cipherFunction byte
switch e.Algo { switch e.Algo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly: case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
e.encryptedMPI1 = new(encoding.MPI) e.encryptedMPI1 = new(encoding.MPI)
@@ -68,26 +123,39 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) {
if _, err = e.encryptedMPI2.ReadFrom(r); err != nil { if _, err = e.encryptedMPI2.ReadFrom(r); err != nil {
return return
} }
case PubKeyAlgoX25519:
e.ephemeralPublicX25519, e.encryptedSession, cipherFunction, err = x25519.DecodeFields(r, e.Version == 6)
if err != nil {
return
}
case PubKeyAlgoX448:
e.ephemeralPublicX448, e.encryptedSession, cipherFunction, err = x448.DecodeFields(r, e.Version == 6)
if err != nil {
return
}
} }
if e.Version < 6 {
switch e.Algo {
case PubKeyAlgoX25519, PubKeyAlgoX448:
e.CipherFunc = CipherFunction(cipherFunction)
// Check for validiy is in the Decrypt method
}
}
_, err = consumeAll(r) _, err = consumeAll(r)
return return
} }
func checksumKeyMaterial(key []byte) uint16 {
var checksum uint16
for _, v := range key {
checksum += uint16(v)
}
return checksum
}
// Decrypt decrypts an encrypted session key with the given private key. The // Decrypt decrypts an encrypted session key with the given private key. The
// private key must have been decrypted first. // private key must have been decrypted first.
// If config is nil, sensible defaults will be used. // If config is nil, sensible defaults will be used.
func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error { func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error {
if e.KeyId != 0 && e.KeyId != priv.KeyId { if e.Version < 6 && e.KeyId != 0 && e.KeyId != priv.KeyId {
return errors.InvalidArgumentError("cannot decrypt encrypted session key for key id " + strconv.FormatUint(e.KeyId, 16) + " with private key id " + strconv.FormatUint(priv.KeyId, 16)) return errors.InvalidArgumentError("cannot decrypt encrypted session key for key id " + strconv.FormatUint(e.KeyId, 16) + " with private key id " + strconv.FormatUint(priv.KeyId, 16))
} }
if e.Version == 6 && e.KeyVersion != 0 && !bytes.Equal(e.KeyFingerprint, priv.Fingerprint) {
return errors.InvalidArgumentError("cannot decrypt encrypted session key for key fingerprint " + hex.EncodeToString(e.KeyFingerprint) + " with private key fingerprint " + hex.EncodeToString(priv.Fingerprint))
}
if e.Algo != priv.PubKeyAlgo { if e.Algo != priv.PubKeyAlgo {
return errors.InvalidArgumentError("cannot decrypt encrypted session key of type " + strconv.Itoa(int(e.Algo)) + " with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo))) return errors.InvalidArgumentError("cannot decrypt encrypted session key of type " + strconv.Itoa(int(e.Algo)) + " with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
} }
@@ -113,52 +181,116 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error {
vsG := e.encryptedMPI1.Bytes() vsG := e.encryptedMPI1.Bytes()
m := e.encryptedMPI2.Bytes() m := e.encryptedMPI2.Bytes()
oid := priv.PublicKey.oid.EncodedBytes() oid := priv.PublicKey.oid.EncodedBytes()
b, err = ecdh.Decrypt(priv.PrivateKey.(*ecdh.PrivateKey), vsG, m, oid, priv.PublicKey.Fingerprint[:]) fp := priv.PublicKey.Fingerprint[:]
if priv.PublicKey.Version == 5 {
// For v5 the, the fingerprint must be restricted to 20 bytes
fp = fp[:20]
}
b, err = ecdh.Decrypt(priv.PrivateKey.(*ecdh.PrivateKey), vsG, m, oid, fp)
case PubKeyAlgoX25519:
b, err = x25519.Decrypt(priv.PrivateKey.(*x25519.PrivateKey), e.ephemeralPublicX25519, e.encryptedSession)
case PubKeyAlgoX448:
b, err = x448.Decrypt(priv.PrivateKey.(*x448.PrivateKey), e.ephemeralPublicX448, e.encryptedSession)
default: default:
err = errors.InvalidArgumentError("cannot decrypt encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo))) err = errors.InvalidArgumentError("cannot decrypt encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo)))
} }
if err != nil { if err != nil {
return err return err
} }
e.CipherFunc = CipherFunction(b[0]) var key []byte
if !e.CipherFunc.IsSupported() { switch priv.PubKeyAlgo {
return errors.UnsupportedError("unsupported encryption function") case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH:
keyOffset := 0
if e.Version < 6 {
e.CipherFunc = CipherFunction(b[0])
keyOffset = 1
if !e.CipherFunc.IsSupported() {
return errors.UnsupportedError("unsupported encryption function")
}
}
key, err = decodeChecksumKey(b[keyOffset:])
if err != nil {
return err
}
case PubKeyAlgoX25519, PubKeyAlgoX448:
if e.Version < 6 {
switch e.CipherFunc {
case CipherAES128, CipherAES192, CipherAES256:
break
default:
return errors.StructuralError("v3 PKESK mandates AES as cipher function for x25519 and x448")
}
}
key = b[:]
default:
return errors.UnsupportedError("unsupported algorithm for decryption")
} }
e.Key = key
e.Key = b[1 : len(b)-2]
expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1])
checksum := checksumKeyMaterial(e.Key)
if checksum != expectedChecksum {
return errors.StructuralError("EncryptedKey checksum incorrect")
}
return nil return nil
} }
// Serialize writes the encrypted key packet, e, to w. // Serialize writes the encrypted key packet, e, to w.
func (e *EncryptedKey) Serialize(w io.Writer) error { func (e *EncryptedKey) Serialize(w io.Writer) error {
var mpiLen int var encodedLength int
switch e.Algo { switch e.Algo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly: case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
mpiLen = int(e.encryptedMPI1.EncodedLength()) encodedLength = int(e.encryptedMPI1.EncodedLength())
case PubKeyAlgoElGamal: case PubKeyAlgoElGamal:
mpiLen = int(e.encryptedMPI1.EncodedLength()) + int(e.encryptedMPI2.EncodedLength()) encodedLength = int(e.encryptedMPI1.EncodedLength()) + int(e.encryptedMPI2.EncodedLength())
case PubKeyAlgoECDH: case PubKeyAlgoECDH:
mpiLen = int(e.encryptedMPI1.EncodedLength()) + int(e.encryptedMPI2.EncodedLength()) encodedLength = int(e.encryptedMPI1.EncodedLength()) + int(e.encryptedMPI2.EncodedLength())
case PubKeyAlgoX25519:
encodedLength = x25519.EncodedFieldsLength(e.encryptedSession, e.Version == 6)
case PubKeyAlgoX448:
encodedLength = x448.EncodedFieldsLength(e.encryptedSession, e.Version == 6)
default: default:
return errors.InvalidArgumentError("don't know how to serialize encrypted key type " + strconv.Itoa(int(e.Algo))) return errors.InvalidArgumentError("don't know how to serialize encrypted key type " + strconv.Itoa(int(e.Algo)))
} }
err := serializeHeader(w, packetTypeEncryptedKey, 1 /* version */ +8 /* key id */ +1 /* algo */ +mpiLen) packetLen := versionSize /* version */ + keyIdSize /* key id */ + algorithmSize /* algo */ + encodedLength
if e.Version == 6 {
packetLen = versionSize /* version */ + algorithmSize /* algo */ + encodedLength + keyVersionSize /* key version */
if e.KeyVersion == 6 {
packetLen += fingerprintSizeV6
} else if e.KeyVersion == 4 {
packetLen += fingerprintSize
}
}
err := serializeHeader(w, packetTypeEncryptedKey, packetLen)
if err != nil { if err != nil {
return err return err
} }
w.Write([]byte{encryptedKeyVersion}) _, err = w.Write([]byte{byte(e.Version)})
binary.Write(w, binary.BigEndian, e.KeyId) if err != nil {
w.Write([]byte{byte(e.Algo)}) return err
}
if e.Version == 6 {
_, err = w.Write([]byte{byte(e.KeyVersion)})
if err != nil {
return err
}
// The key version number may also be zero,
// and the fingerprint omitted
if e.KeyVersion != 0 {
_, err = w.Write(e.KeyFingerprint)
if err != nil {
return err
}
}
} else {
// Write KeyID
err = binary.Write(w, binary.BigEndian, e.KeyId)
if err != nil {
return err
}
}
_, err = w.Write([]byte{byte(e.Algo)})
if err != nil {
return err
}
switch e.Algo { switch e.Algo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly: case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
@@ -176,34 +308,115 @@ func (e *EncryptedKey) Serialize(w io.Writer) error {
} }
_, err := w.Write(e.encryptedMPI2.EncodedBytes()) _, err := w.Write(e.encryptedMPI2.EncodedBytes())
return err return err
case PubKeyAlgoX25519:
err := x25519.EncodeFields(w, e.ephemeralPublicX25519, e.encryptedSession, byte(e.CipherFunc), e.Version == 6)
return err
case PubKeyAlgoX448:
err := x448.EncodeFields(w, e.ephemeralPublicX448, e.encryptedSession, byte(e.CipherFunc), e.Version == 6)
return err
default: default:
panic("internal error") panic("internal error")
} }
} }
// SerializeEncryptedKey serializes an encrypted key packet to w that contains // SerializeEncryptedKeyAEAD serializes an encrypted key packet to w that contains
// key, encrypted to pub. // key, encrypted to pub.
// If aeadSupported is set, PKESK v6 is used, otherwise v3.
// Note: aeadSupported MUST match the value passed to SerializeSymmetricallyEncrypted.
// If config is nil, sensible defaults will be used. // If config is nil, sensible defaults will be used.
func SerializeEncryptedKey(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, key []byte, config *Config) error { func SerializeEncryptedKeyAEAD(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, aeadSupported bool, key []byte, config *Config) error {
var buf [10]byte return SerializeEncryptedKeyAEADwithHiddenOption(w, pub, cipherFunc, aeadSupported, key, false, config)
buf[0] = encryptedKeyVersion }
binary.BigEndian.PutUint64(buf[1:9], pub.KeyId)
buf[9] = byte(pub.PubKeyAlgo)
keyBlock := make([]byte, 1 /* cipher type */ +len(key)+2 /* checksum */) // SerializeEncryptedKeyAEADwithHiddenOption serializes an encrypted key packet to w that contains
keyBlock[0] = byte(cipherFunc) // key, encrypted to pub.
copy(keyBlock[1:], key) // Offers the hidden flag option to indicated if the PKESK packet should include a wildcard KeyID.
checksum := checksumKeyMaterial(key) // If aeadSupported is set, PKESK v6 is used, otherwise v3.
keyBlock[1+len(key)] = byte(checksum >> 8) // Note: aeadSupported MUST match the value passed to SerializeSymmetricallyEncrypted.
keyBlock[1+len(key)+1] = byte(checksum) // If config is nil, sensible defaults will be used.
func SerializeEncryptedKeyAEADwithHiddenOption(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, aeadSupported bool, key []byte, hidden bool, config *Config) error {
var buf [36]byte // max possible header size is v6
lenHeaderWritten := versionSize
version := 3
if aeadSupported {
version = 6
}
// An implementation MUST NOT generate ElGamal v6 PKESKs.
if version == 6 && pub.PubKeyAlgo == PubKeyAlgoElGamal {
return errors.InvalidArgumentError("ElGamal v6 PKESK are not allowed")
}
// In v3 PKESKs, for x25519 and x448, mandate using AES
if version == 3 && (pub.PubKeyAlgo == PubKeyAlgoX25519 || pub.PubKeyAlgo == PubKeyAlgoX448) {
switch cipherFunc {
case CipherAES128, CipherAES192, CipherAES256:
break
default:
return errors.InvalidArgumentError("v3 PKESK mandates AES for x25519 and x448")
}
}
buf[0] = byte(version)
// If hidden is set, the key should be hidden
// An implementation MAY accept or use a Key ID of all zeros,
// or a key version of zero and no key fingerprint, to hide the intended decryption key.
// See Section 5.1.8. in the open pgp crypto refresh
if version == 6 {
if !hidden {
// A one-octet size of the following two fields.
buf[1] = byte(keyVersionSize + len(pub.Fingerprint))
// A one octet key version number.
buf[2] = byte(pub.Version)
lenHeaderWritten += keyVersionSize + 1
// The fingerprint of the public key
copy(buf[lenHeaderWritten:lenHeaderWritten+len(pub.Fingerprint)], pub.Fingerprint)
lenHeaderWritten += len(pub.Fingerprint)
} else {
// The size may also be zero, and the key version
// and fingerprint omitted for an "anonymous recipient"
buf[1] = 0
lenHeaderWritten += 1
}
} else {
if !hidden {
binary.BigEndian.PutUint64(buf[versionSize:(versionSize+keyIdSize)], pub.KeyId)
}
lenHeaderWritten += keyIdSize
}
buf[lenHeaderWritten] = byte(pub.PubKeyAlgo)
lenHeaderWritten += algorithmSize
var keyBlock []byte
switch pub.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH:
lenKeyBlock := len(key) + 2
if version < 6 {
lenKeyBlock += 1 // cipher type included
}
keyBlock = make([]byte, lenKeyBlock)
keyOffset := 0
if version < 6 {
keyBlock[0] = byte(cipherFunc)
keyOffset = 1
}
encodeChecksumKey(keyBlock[keyOffset:], key)
case PubKeyAlgoX25519, PubKeyAlgoX448:
// algorithm is added in plaintext below
keyBlock = key
}
switch pub.PubKeyAlgo { switch pub.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly: case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly:
return serializeEncryptedKeyRSA(w, config.Random(), buf, pub.PublicKey.(*rsa.PublicKey), keyBlock) return serializeEncryptedKeyRSA(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*rsa.PublicKey), keyBlock)
case PubKeyAlgoElGamal: case PubKeyAlgoElGamal:
return serializeEncryptedKeyElGamal(w, config.Random(), buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock) return serializeEncryptedKeyElGamal(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*elgamal.PublicKey), keyBlock)
case PubKeyAlgoECDH: case PubKeyAlgoECDH:
return serializeEncryptedKeyECDH(w, config.Random(), buf, pub.PublicKey.(*ecdh.PublicKey), keyBlock, pub.oid, pub.Fingerprint) return serializeEncryptedKeyECDH(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*ecdh.PublicKey), keyBlock, pub.oid, pub.Fingerprint)
case PubKeyAlgoX25519:
return serializeEncryptedKeyX25519(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*x25519.PublicKey), keyBlock, byte(cipherFunc), version)
case PubKeyAlgoX448:
return serializeEncryptedKeyX448(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*x448.PublicKey), keyBlock, byte(cipherFunc), version)
case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly: case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly:
return errors.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo))) return errors.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
} }
@@ -211,14 +424,32 @@ func SerializeEncryptedKey(w io.Writer, pub *PublicKey, cipherFunc CipherFunctio
return errors.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo))) return errors.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo)))
} }
func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) error { // SerializeEncryptedKey serializes an encrypted key packet to w that contains
// key, encrypted to pub.
// PKESKv6 is used if config.AEAD() is not nil.
// If config is nil, sensible defaults will be used.
// Deprecated: Use SerializeEncryptedKeyAEAD instead.
func SerializeEncryptedKey(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, key []byte, config *Config) error {
return SerializeEncryptedKeyAEAD(w, pub, cipherFunc, config.AEAD() != nil, key, config)
}
// SerializeEncryptedKeyWithHiddenOption serializes an encrypted key packet to w that contains
// key, encrypted to pub. PKESKv6 is used if config.AEAD() is not nil.
// The hidden option controls if the packet should be anonymous, i.e., omit key metadata.
// If config is nil, sensible defaults will be used.
// Deprecated: Use SerializeEncryptedKeyAEADwithHiddenOption instead.
func SerializeEncryptedKeyWithHiddenOption(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, key []byte, hidden bool, config *Config) error {
return SerializeEncryptedKeyAEADwithHiddenOption(w, pub, cipherFunc, config.AEAD() != nil, key, hidden, config)
}
func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header []byte, pub *rsa.PublicKey, keyBlock []byte) error {
cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock) cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock)
if err != nil { if err != nil {
return errors.InvalidArgumentError("RSA encryption failed: " + err.Error()) return errors.InvalidArgumentError("RSA encryption failed: " + err.Error())
} }
cipherMPI := encoding.NewMPI(cipherText) cipherMPI := encoding.NewMPI(cipherText)
packetLen := 10 /* header length */ + int(cipherMPI.EncodedLength()) packetLen := len(header) /* header length */ + int(cipherMPI.EncodedLength())
err = serializeHeader(w, packetTypeEncryptedKey, packetLen) err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
if err != nil { if err != nil {
@@ -232,13 +463,13 @@ func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub
return err return err
} }
func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) error { func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header []byte, pub *elgamal.PublicKey, keyBlock []byte) error {
c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock) c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock)
if err != nil { if err != nil {
return errors.InvalidArgumentError("ElGamal encryption failed: " + err.Error()) return errors.InvalidArgumentError("ElGamal encryption failed: " + err.Error())
} }
packetLen := 10 /* header length */ packetLen := len(header) /* header length */
packetLen += 2 /* mpi size */ + (c1.BitLen()+7)/8 packetLen += 2 /* mpi size */ + (c1.BitLen()+7)/8
packetLen += 2 /* mpi size */ + (c2.BitLen()+7)/8 packetLen += 2 /* mpi size */ + (c2.BitLen()+7)/8
@@ -257,7 +488,7 @@ func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte,
return err return err
} }
func serializeEncryptedKeyECDH(w io.Writer, rand io.Reader, header [10]byte, pub *ecdh.PublicKey, keyBlock []byte, oid encoding.Field, fingerprint []byte) error { func serializeEncryptedKeyECDH(w io.Writer, rand io.Reader, header []byte, pub *ecdh.PublicKey, keyBlock []byte, oid encoding.Field, fingerprint []byte) error {
vsG, c, err := ecdh.Encrypt(rand, pub, keyBlock, oid.EncodedBytes(), fingerprint) vsG, c, err := ecdh.Encrypt(rand, pub, keyBlock, oid.EncodedBytes(), fingerprint)
if err != nil { if err != nil {
return errors.InvalidArgumentError("ECDH encryption failed: " + err.Error()) return errors.InvalidArgumentError("ECDH encryption failed: " + err.Error())
@@ -266,7 +497,7 @@ func serializeEncryptedKeyECDH(w io.Writer, rand io.Reader, header [10]byte, pub
g := encoding.NewMPI(vsG) g := encoding.NewMPI(vsG)
m := encoding.NewOID(c) m := encoding.NewOID(c)
packetLen := 10 /* header length */ packetLen := len(header) /* header length */
packetLen += int(g.EncodedLength()) + int(m.EncodedLength()) packetLen += int(g.EncodedLength()) + int(m.EncodedLength())
err = serializeHeader(w, packetTypeEncryptedKey, packetLen) err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
@@ -284,3 +515,70 @@ func serializeEncryptedKeyECDH(w io.Writer, rand io.Reader, header [10]byte, pub
_, err = w.Write(m.EncodedBytes()) _, err = w.Write(m.EncodedBytes())
return err return err
} }
func serializeEncryptedKeyX25519(w io.Writer, rand io.Reader, header []byte, pub *x25519.PublicKey, keyBlock []byte, cipherFunc byte, version int) error {
ephemeralPublicX25519, ciphertext, err := x25519.Encrypt(rand, pub, keyBlock)
if err != nil {
return errors.InvalidArgumentError("x25519 encryption failed: " + err.Error())
}
packetLen := len(header) /* header length */
packetLen += x25519.EncodedFieldsLength(ciphertext, version == 6)
err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
if err != nil {
return err
}
_, err = w.Write(header[:])
if err != nil {
return err
}
return x25519.EncodeFields(w, ephemeralPublicX25519, ciphertext, cipherFunc, version == 6)
}
func serializeEncryptedKeyX448(w io.Writer, rand io.Reader, header []byte, pub *x448.PublicKey, keyBlock []byte, cipherFunc byte, version int) error {
ephemeralPublicX448, ciphertext, err := x448.Encrypt(rand, pub, keyBlock)
if err != nil {
return errors.InvalidArgumentError("x448 encryption failed: " + err.Error())
}
packetLen := len(header) /* header length */
packetLen += x448.EncodedFieldsLength(ciphertext, version == 6)
err = serializeHeader(w, packetTypeEncryptedKey, packetLen)
if err != nil {
return err
}
_, err = w.Write(header[:])
if err != nil {
return err
}
return x448.EncodeFields(w, ephemeralPublicX448, ciphertext, cipherFunc, version == 6)
}
func checksumKeyMaterial(key []byte) uint16 {
var checksum uint16
for _, v := range key {
checksum += uint16(v)
}
return checksum
}
func decodeChecksumKey(msg []byte) (key []byte, err error) {
key = msg[:len(msg)-2]
expectedChecksum := uint16(msg[len(msg)-2])<<8 | uint16(msg[len(msg)-1])
checksum := checksumKeyMaterial(key)
if checksum != expectedChecksum {
err = errors.StructuralError("session key checksum is incorrect")
}
return
}
func encodeChecksumKey(buffer []byte, key []byte) {
copy(buffer, key)
checksum := checksumKeyMaterial(key)
buffer[len(key)] = byte(checksum >> 8)
buffer[len(key)+1] = byte(checksum)
}

View File

@@ -58,9 +58,9 @@ func (l *LiteralData) parse(r io.Reader) (err error) {
// on completion. The fileName is truncated to 255 bytes. // on completion. The fileName is truncated to 255 bytes.
func SerializeLiteral(w io.WriteCloser, isBinary bool, fileName string, time uint32) (plaintext io.WriteCloser, err error) { func SerializeLiteral(w io.WriteCloser, isBinary bool, fileName string, time uint32) (plaintext io.WriteCloser, err error) {
var buf [4]byte var buf [4]byte
buf[0] = 't' buf[0] = 'b'
if isBinary { if !isBinary {
buf[0] = 'b' buf[0] = 'u'
} }
if len(fileName) > 255 { if len(fileName) > 255 {
fileName = fileName[:255] fileName = fileName[:255]

View File

@@ -0,0 +1,33 @@
package packet
import (
"io"
"github.com/ProtonMail/go-crypto/openpgp/errors"
)
type Marker struct{}
const markerString = "PGP"
// parse just checks if the packet contains "PGP".
func (m *Marker) parse(reader io.Reader) error {
var buffer [3]byte
if _, err := io.ReadFull(reader, buffer[:]); err != nil {
return err
}
if string(buffer[:]) != markerString {
return errors.StructuralError("invalid marker packet")
}
return nil
}
// SerializeMarker writes a marker packet to writer.
func SerializeMarker(writer io.Writer) error {
err := serializeHeader(writer, packetTypeMarker, len(markerString))
if err != nil {
return err
}
_, err = writer.Write([]byte(markerString))
return err
}

View File

@@ -7,34 +7,37 @@ package packet
import ( import (
"crypto" "crypto"
"encoding/binary" "encoding/binary"
"github.com/ProtonMail/go-crypto/openpgp/errors"
"github.com/ProtonMail/go-crypto/openpgp/internal/algorithm"
"io" "io"
"strconv" "strconv"
"github.com/ProtonMail/go-crypto/openpgp/errors"
"github.com/ProtonMail/go-crypto/openpgp/internal/algorithm"
) )
// OnePassSignature represents a one-pass signature packet. See RFC 4880, // OnePassSignature represents a one-pass signature packet. See RFC 4880,
// section 5.4. // section 5.4.
type OnePassSignature struct { type OnePassSignature struct {
SigType SignatureType Version int
Hash crypto.Hash SigType SignatureType
PubKeyAlgo PublicKeyAlgorithm Hash crypto.Hash
KeyId uint64 PubKeyAlgo PublicKeyAlgorithm
IsLast bool KeyId uint64
IsLast bool
Salt []byte // v6 only
KeyFingerprint []byte // v6 only
} }
const onePassSignatureVersion = 3
func (ops *OnePassSignature) parse(r io.Reader) (err error) { func (ops *OnePassSignature) parse(r io.Reader) (err error) {
var buf [13]byte var buf [8]byte
// Read: version | signature type | hash algorithm | public-key algorithm
_, err = readFull(r, buf[:]) _, err = readFull(r, buf[:4])
if err != nil { if err != nil {
return return
} }
if buf[0] != onePassSignatureVersion { if buf[0] != 3 && buf[0] != 6 {
err = errors.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0]))) return errors.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0])))
} }
ops.Version = int(buf[0])
var ok bool var ok bool
ops.Hash, ok = algorithm.HashIdToHashWithSha1(buf[2]) ops.Hash, ok = algorithm.HashIdToHashWithSha1(buf[2])
@@ -44,15 +47,69 @@ func (ops *OnePassSignature) parse(r io.Reader) (err error) {
ops.SigType = SignatureType(buf[1]) ops.SigType = SignatureType(buf[1])
ops.PubKeyAlgo = PublicKeyAlgorithm(buf[3]) ops.PubKeyAlgo = PublicKeyAlgorithm(buf[3])
ops.KeyId = binary.BigEndian.Uint64(buf[4:12])
ops.IsLast = buf[12] != 0 if ops.Version == 6 {
// Only for v6, a variable-length field containing the salt
_, err = readFull(r, buf[:1])
if err != nil {
return
}
saltLength := int(buf[0])
var expectedSaltLength int
expectedSaltLength, err = SaltLengthForHash(ops.Hash)
if err != nil {
return
}
if saltLength != expectedSaltLength {
err = errors.StructuralError("unexpected salt size for the given hash algorithm")
return
}
salt := make([]byte, expectedSaltLength)
_, err = readFull(r, salt)
if err != nil {
return
}
ops.Salt = salt
// Only for v6 packets, 32 octets of the fingerprint of the signing key.
fingerprint := make([]byte, 32)
_, err = readFull(r, fingerprint)
if err != nil {
return
}
ops.KeyFingerprint = fingerprint
ops.KeyId = binary.BigEndian.Uint64(ops.KeyFingerprint[:8])
} else {
_, err = readFull(r, buf[:8])
if err != nil {
return
}
ops.KeyId = binary.BigEndian.Uint64(buf[:8])
}
_, err = readFull(r, buf[:1])
if err != nil {
return
}
ops.IsLast = buf[0] != 0
return return
} }
// Serialize marshals the given OnePassSignature to w. // Serialize marshals the given OnePassSignature to w.
func (ops *OnePassSignature) Serialize(w io.Writer) error { func (ops *OnePassSignature) Serialize(w io.Writer) error {
var buf [13]byte //v3 length 1+1+1+1+8+1 =
buf[0] = onePassSignatureVersion packetLength := 13
if ops.Version == 6 {
// v6 length 1+1+1+1+1+len(salt)+32+1 =
packetLength = 38 + len(ops.Salt)
}
if err := serializeHeader(w, packetTypeOnePassSignature, packetLength); err != nil {
return err
}
var buf [8]byte
buf[0] = byte(ops.Version)
buf[1] = uint8(ops.SigType) buf[1] = uint8(ops.SigType)
var ok bool var ok bool
buf[2], ok = algorithm.HashToHashIdWithSha1(ops.Hash) buf[2], ok = algorithm.HashToHashIdWithSha1(ops.Hash)
@@ -60,14 +117,41 @@ func (ops *OnePassSignature) Serialize(w io.Writer) error {
return errors.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash))) return errors.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash)))
} }
buf[3] = uint8(ops.PubKeyAlgo) buf[3] = uint8(ops.PubKeyAlgo)
binary.BigEndian.PutUint64(buf[4:12], ops.KeyId)
if ops.IsLast {
buf[12] = 1
}
if err := serializeHeader(w, packetTypeOnePassSignature, len(buf)); err != nil { _, err := w.Write(buf[:4])
if err != nil {
return err return err
} }
_, err := w.Write(buf[:])
if ops.Version == 6 {
// write salt for v6 signatures
_, err := w.Write([]byte{uint8(len(ops.Salt))})
if err != nil {
return err
}
_, err = w.Write(ops.Salt)
if err != nil {
return err
}
// write fingerprint v6 signatures
_, err = w.Write(ops.KeyFingerprint)
if err != nil {
return err
}
} else {
binary.BigEndian.PutUint64(buf[:8], ops.KeyId)
_, err := w.Write(buf[:8])
if err != nil {
return err
}
}
isLast := []byte{byte(0)}
if ops.IsLast {
isLast[0] = 1
}
_, err = w.Write(isLast)
return err return err
} }

View File

@@ -7,7 +7,6 @@ package packet
import ( import (
"bytes" "bytes"
"io" "io"
"io/ioutil"
"github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/errors"
) )
@@ -26,7 +25,7 @@ type OpaquePacket struct {
} }
func (op *OpaquePacket) parse(r io.Reader) (err error) { func (op *OpaquePacket) parse(r io.Reader) (err error) {
op.Contents, err = ioutil.ReadAll(r) op.Contents, err = io.ReadAll(r)
return return
} }

View File

@@ -311,12 +311,15 @@ const (
packetTypePrivateSubkey packetType = 7 packetTypePrivateSubkey packetType = 7
packetTypeCompressed packetType = 8 packetTypeCompressed packetType = 8
packetTypeSymmetricallyEncrypted packetType = 9 packetTypeSymmetricallyEncrypted packetType = 9
packetTypeMarker packetType = 10
packetTypeLiteralData packetType = 11 packetTypeLiteralData packetType = 11
packetTypeTrust packetType = 12
packetTypeUserId packetType = 13 packetTypeUserId packetType = 13
packetTypePublicSubkey packetType = 14 packetTypePublicSubkey packetType = 14
packetTypeUserAttribute packetType = 17 packetTypeUserAttribute packetType = 17
packetTypeSymmetricallyEncryptedIntegrityProtected packetType = 18 packetTypeSymmetricallyEncryptedIntegrityProtected packetType = 18
packetTypeAEADEncrypted packetType = 20 packetTypeAEADEncrypted packetType = 20
packetPadding packetType = 21
) )
// EncryptedDataPacket holds encrypted data. It is currently implemented by // EncryptedDataPacket holds encrypted data. It is currently implemented by
@@ -328,7 +331,7 @@ type EncryptedDataPacket interface {
// Read reads a single OpenPGP packet from the given io.Reader. If there is an // Read reads a single OpenPGP packet from the given io.Reader. If there is an
// error parsing a packet, the whole packet is consumed from the input. // error parsing a packet, the whole packet is consumed from the input.
func Read(r io.Reader) (p Packet, err error) { func Read(r io.Reader) (p Packet, err error) {
tag, _, contents, err := readHeader(r) tag, len, contents, err := readHeader(r)
if err != nil { if err != nil {
return return
} }
@@ -367,8 +370,93 @@ func Read(r io.Reader) (p Packet, err error) {
p = se p = se
case packetTypeAEADEncrypted: case packetTypeAEADEncrypted:
p = new(AEADEncrypted) p = new(AEADEncrypted)
default: case packetPadding:
p = Padding(len)
case packetTypeMarker:
p = new(Marker)
case packetTypeTrust:
// Not implemented, just consume
err = errors.UnknownPacketTypeError(tag) err = errors.UnknownPacketTypeError(tag)
default:
// Packet Tags from 0 to 39 are critical.
// Packet Tags from 40 to 63 are non-critical.
if tag < 40 {
err = errors.CriticalUnknownPacketTypeError(tag)
} else {
err = errors.UnknownPacketTypeError(tag)
}
}
if p != nil {
err = p.parse(contents)
}
if err != nil {
consumeAll(contents)
}
return
}
// ReadWithCheck reads a single OpenPGP message packet from the given io.Reader. If there is an
// error parsing a packet, the whole packet is consumed from the input.
// ReadWithCheck additionally checks if the OpenPGP message packet sequence adheres
// to the packet composition rules in rfc4880, if not throws an error.
func ReadWithCheck(r io.Reader, sequence *SequenceVerifier) (p Packet, msgErr error, err error) {
tag, len, contents, err := readHeader(r)
if err != nil {
return
}
switch tag {
case packetTypeEncryptedKey:
msgErr = sequence.Next(ESKSymbol)
p = new(EncryptedKey)
case packetTypeSignature:
msgErr = sequence.Next(SigSymbol)
p = new(Signature)
case packetTypeSymmetricKeyEncrypted:
msgErr = sequence.Next(ESKSymbol)
p = new(SymmetricKeyEncrypted)
case packetTypeOnePassSignature:
msgErr = sequence.Next(OPSSymbol)
p = new(OnePassSignature)
case packetTypeCompressed:
msgErr = sequence.Next(CompSymbol)
p = new(Compressed)
case packetTypeSymmetricallyEncrypted:
msgErr = sequence.Next(EncSymbol)
p = new(SymmetricallyEncrypted)
case packetTypeLiteralData:
msgErr = sequence.Next(LDSymbol)
p = new(LiteralData)
case packetTypeSymmetricallyEncryptedIntegrityProtected:
msgErr = sequence.Next(EncSymbol)
se := new(SymmetricallyEncrypted)
se.IntegrityProtected = true
p = se
case packetTypeAEADEncrypted:
msgErr = sequence.Next(EncSymbol)
p = new(AEADEncrypted)
case packetPadding:
p = Padding(len)
case packetTypeMarker:
p = new(Marker)
case packetTypeTrust:
// Not implemented, just consume
err = errors.UnknownPacketTypeError(tag)
case packetTypePrivateKey,
packetTypePrivateSubkey,
packetTypePublicKey,
packetTypePublicSubkey,
packetTypeUserId,
packetTypeUserAttribute:
msgErr = sequence.Next(UnknownSymbol)
consumeAll(contents)
default:
// Packet Tags from 0 to 39 are critical.
// Packet Tags from 40 to 63 are non-critical.
if tag < 40 {
err = errors.CriticalUnknownPacketTypeError(tag)
} else {
err = errors.UnknownPacketTypeError(tag)
}
} }
if p != nil { if p != nil {
err = p.parse(contents) err = p.parse(contents)
@@ -385,17 +473,17 @@ type SignatureType uint8
const ( const (
SigTypeBinary SignatureType = 0x00 SigTypeBinary SignatureType = 0x00
SigTypeText = 0x01 SigTypeText SignatureType = 0x01
SigTypeGenericCert = 0x10 SigTypeGenericCert SignatureType = 0x10
SigTypePersonaCert = 0x11 SigTypePersonaCert SignatureType = 0x11
SigTypeCasualCert = 0x12 SigTypeCasualCert SignatureType = 0x12
SigTypePositiveCert = 0x13 SigTypePositiveCert SignatureType = 0x13
SigTypeSubkeyBinding = 0x18 SigTypeSubkeyBinding SignatureType = 0x18
SigTypePrimaryKeyBinding = 0x19 SigTypePrimaryKeyBinding SignatureType = 0x19
SigTypeDirectSignature = 0x1F SigTypeDirectSignature SignatureType = 0x1F
SigTypeKeyRevocation = 0x20 SigTypeKeyRevocation SignatureType = 0x20
SigTypeSubkeyRevocation = 0x28 SigTypeSubkeyRevocation SignatureType = 0x28
SigTypeCertificationRevocation = 0x30 SigTypeCertificationRevocation SignatureType = 0x30
) )
// PublicKeyAlgorithm represents the different public key system specified for // PublicKeyAlgorithm represents the different public key system specified for
@@ -412,6 +500,11 @@ const (
PubKeyAlgoECDSA PublicKeyAlgorithm = 19 PubKeyAlgoECDSA PublicKeyAlgorithm = 19
// https://www.ietf.org/archive/id/draft-koch-eddsa-for-openpgp-04.txt // https://www.ietf.org/archive/id/draft-koch-eddsa-for-openpgp-04.txt
PubKeyAlgoEdDSA PublicKeyAlgorithm = 22 PubKeyAlgoEdDSA PublicKeyAlgorithm = 22
// https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh
PubKeyAlgoX25519 PublicKeyAlgorithm = 25
PubKeyAlgoX448 PublicKeyAlgorithm = 26
PubKeyAlgoEd25519 PublicKeyAlgorithm = 27
PubKeyAlgoEd448 PublicKeyAlgorithm = 28
// Deprecated in RFC 4880, Section 13.5. Use key flags instead. // Deprecated in RFC 4880, Section 13.5. Use key flags instead.
PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2 PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2
@@ -422,7 +515,7 @@ const (
// key of the given type. // key of the given type.
func (pka PublicKeyAlgorithm) CanEncrypt() bool { func (pka PublicKeyAlgorithm) CanEncrypt() bool {
switch pka { switch pka {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH: case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH, PubKeyAlgoX25519, PubKeyAlgoX448:
return true return true
} }
return false return false
@@ -432,7 +525,7 @@ func (pka PublicKeyAlgorithm) CanEncrypt() bool {
// sign a message. // sign a message.
func (pka PublicKeyAlgorithm) CanSign() bool { func (pka PublicKeyAlgorithm) CanSign() bool {
switch pka { switch pka {
case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA: case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA, PubKeyAlgoEd25519, PubKeyAlgoEd448:
return true return true
} }
return false return false
@@ -512,6 +605,11 @@ func (mode AEADMode) TagLength() int {
return algorithm.AEADMode(mode).TagLength() return algorithm.AEADMode(mode).TagLength()
} }
// IsSupported returns true if the aead mode is supported from the library
func (mode AEADMode) IsSupported() bool {
return algorithm.AEADMode(mode).TagLength() > 0
}
// new returns a fresh instance of the given mode. // new returns a fresh instance of the given mode.
func (mode AEADMode) new(block cipher.Block) cipher.AEAD { func (mode AEADMode) new(block cipher.Block) cipher.AEAD {
return algorithm.AEADMode(mode).New(block) return algorithm.AEADMode(mode).New(block)
@@ -526,8 +624,17 @@ const (
KeySuperseded ReasonForRevocation = 1 KeySuperseded ReasonForRevocation = 1
KeyCompromised ReasonForRevocation = 2 KeyCompromised ReasonForRevocation = 2
KeyRetired ReasonForRevocation = 3 KeyRetired ReasonForRevocation = 3
UserIDNotValid ReasonForRevocation = 32
Unknown ReasonForRevocation = 200
) )
func NewReasonForRevocation(value byte) ReasonForRevocation {
if value < 4 || value == 32 {
return ReasonForRevocation(value)
}
return Unknown
}
// Curve is a mapping to supported ECC curves for key generation. // Curve is a mapping to supported ECC curves for key generation.
// See https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-06.html#name-curve-specific-wire-formats // See https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-06.html#name-curve-specific-wire-formats
type Curve string type Curve string
@@ -549,3 +656,20 @@ type TrustLevel uint8
// TrustAmount represents a trust amount per RFC4880 5.2.3.13 // TrustAmount represents a trust amount per RFC4880 5.2.3.13
type TrustAmount uint8 type TrustAmount uint8
const (
// versionSize is the length in bytes of the version value.
versionSize = 1
// algorithmSize is the length in bytes of the key algorithm value.
algorithmSize = 1
// keyVersionSize is the length in bytes of the key version value
keyVersionSize = 1
// keyIdSize is the length in bytes of the key identifier value.
keyIdSize = 8
// timestampSize is the length in bytes of encoded timestamps.
timestampSize = 4
// fingerprintSizeV6 is the length in bytes of the key fingerprint in v6.
fingerprintSizeV6 = 32
// fingerprintSize is the length in bytes of the key fingerprint.
fingerprintSize = 20
)

View File

@@ -0,0 +1,222 @@
package packet
// This file implements the pushdown automata (PDA) from PGPainless (Paul Schaub)
// to verify pgp packet sequences. See Paul's blogpost for more details:
// https://blog.jabberhead.tk/2022/10/26/implementing-packet-sequence-validation-using-pushdown-automata/
import (
"fmt"
"github.com/ProtonMail/go-crypto/openpgp/errors"
)
func NewErrMalformedMessage(from State, input InputSymbol, stackSymbol StackSymbol) errors.ErrMalformedMessage {
return errors.ErrMalformedMessage(fmt.Sprintf("state %d, input symbol %d, stack symbol %d ", from, input, stackSymbol))
}
// InputSymbol defines the input alphabet of the PDA
type InputSymbol uint8
const (
LDSymbol InputSymbol = iota
SigSymbol
OPSSymbol
CompSymbol
ESKSymbol
EncSymbol
EOSSymbol
UnknownSymbol
)
// StackSymbol defines the stack alphabet of the PDA
type StackSymbol int8
const (
MsgStackSymbol StackSymbol = iota
OpsStackSymbol
KeyStackSymbol
EndStackSymbol
EmptyStackSymbol
)
// State defines the states of the PDA
type State int8
const (
OpenPGPMessage State = iota
ESKMessage
LiteralMessage
CompressedMessage
EncryptedMessage
ValidMessage
)
// transition represents a state transition in the PDA
type transition func(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error)
// SequenceVerifier is a pushdown automata to verify
// PGP messages packet sequences according to rfc4880.
type SequenceVerifier struct {
stack []StackSymbol
state State
}
// Next performs a state transition with the given input symbol.
// If the transition fails a ErrMalformedMessage is returned.
func (sv *SequenceVerifier) Next(input InputSymbol) error {
for {
stackSymbol := sv.popStack()
transitionFunc := getTransition(sv.state)
nextState, newStackSymbols, redo, err := transitionFunc(input, stackSymbol)
if err != nil {
return err
}
if redo {
sv.pushStack(stackSymbol)
}
for _, newStackSymbol := range newStackSymbols {
sv.pushStack(newStackSymbol)
}
sv.state = nextState
if !redo {
break
}
}
return nil
}
// Valid returns true if RDA is in a valid state.
func (sv *SequenceVerifier) Valid() bool {
return sv.state == ValidMessage && len(sv.stack) == 0
}
func (sv *SequenceVerifier) AssertValid() error {
if !sv.Valid() {
return errors.ErrMalformedMessage("invalid message")
}
return nil
}
func NewSequenceVerifier() *SequenceVerifier {
return &SequenceVerifier{
stack: []StackSymbol{EndStackSymbol, MsgStackSymbol},
state: OpenPGPMessage,
}
}
func (sv *SequenceVerifier) popStack() StackSymbol {
if len(sv.stack) == 0 {
return EmptyStackSymbol
}
elemIndex := len(sv.stack) - 1
stackSymbol := sv.stack[elemIndex]
sv.stack = sv.stack[:elemIndex]
return stackSymbol
}
func (sv *SequenceVerifier) pushStack(stackSymbol StackSymbol) {
sv.stack = append(sv.stack, stackSymbol)
}
func getTransition(from State) transition {
switch from {
case OpenPGPMessage:
return fromOpenPGPMessage
case LiteralMessage:
return fromLiteralMessage
case CompressedMessage:
return fromCompressedMessage
case EncryptedMessage:
return fromEncryptedMessage
case ESKMessage:
return fromESKMessage
case ValidMessage:
return fromValidMessage
}
return nil
}
// fromOpenPGPMessage is the transition for the state OpenPGPMessage.
func fromOpenPGPMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
if stackSymbol != MsgStackSymbol {
return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol)
}
switch input {
case LDSymbol:
return LiteralMessage, nil, false, nil
case SigSymbol:
return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, false, nil
case OPSSymbol:
return OpenPGPMessage, []StackSymbol{OpsStackSymbol, MsgStackSymbol}, false, nil
case CompSymbol:
return CompressedMessage, nil, false, nil
case ESKSymbol:
return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil
case EncSymbol:
return EncryptedMessage, nil, false, nil
}
return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol)
}
// fromESKMessage is the transition for the state ESKMessage.
func fromESKMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
if stackSymbol != KeyStackSymbol {
return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol)
}
switch input {
case ESKSymbol:
return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil
case EncSymbol:
return EncryptedMessage, nil, false, nil
}
return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol)
}
// fromLiteralMessage is the transition for the state LiteralMessage.
func fromLiteralMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
switch input {
case SigSymbol:
if stackSymbol == OpsStackSymbol {
return LiteralMessage, nil, false, nil
}
case EOSSymbol:
if stackSymbol == EndStackSymbol {
return ValidMessage, nil, false, nil
}
}
return 0, nil, false, NewErrMalformedMessage(LiteralMessage, input, stackSymbol)
}
// fromLiteralMessage is the transition for the state CompressedMessage.
func fromCompressedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
switch input {
case SigSymbol:
if stackSymbol == OpsStackSymbol {
return CompressedMessage, nil, false, nil
}
case EOSSymbol:
if stackSymbol == EndStackSymbol {
return ValidMessage, nil, false, nil
}
}
return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil
}
// fromEncryptedMessage is the transition for the state EncryptedMessage.
func fromEncryptedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
switch input {
case SigSymbol:
if stackSymbol == OpsStackSymbol {
return EncryptedMessage, nil, false, nil
}
case EOSSymbol:
if stackSymbol == EndStackSymbol {
return ValidMessage, nil, false, nil
}
}
return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil
}
// fromValidMessage is the transition for the state ValidMessage.
func fromValidMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) {
return 0, nil, false, NewErrMalformedMessage(ValidMessage, input, stackSymbol)
}

View File

@@ -0,0 +1,24 @@
package packet
import (
"io"
"github.com/ProtonMail/go-crypto/openpgp/errors"
)
// UnsupportedPackage represents a OpenPGP packet with a known packet type
// but with unsupported content.
type UnsupportedPacket struct {
IncompletePacket Packet
Error errors.UnsupportedError
}
// Implements the Packet interface
func (up *UnsupportedPacket) parse(read io.Reader) error {
err := up.IncompletePacket.parse(read)
if castedErr, ok := err.(errors.UnsupportedError); ok {
up.Error = castedErr
return nil
}
return err
}

View File

@@ -0,0 +1,26 @@
package packet
import (
"io"
)
// Padding type represents a Padding Packet (Tag 21).
// The padding type is represented by the length of its padding.
// see https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh#name-padding-packet-tag-21
type Padding int
// parse just ignores the padding content.
func (pad Padding) parse(reader io.Reader) error {
_, err := io.CopyN(io.Discard, reader, int64(pad))
return err
}
// SerializePadding writes the padding to writer.
func (pad Padding) SerializePadding(writer io.Writer, rand io.Reader) error {
err := serializeHeader(writer, packetPadding, int(pad))
if err != nil {
return err
}
_, err = io.CopyN(writer, rand, int64(pad))
return err
}

View File

@@ -9,22 +9,28 @@ import (
"crypto" "crypto"
"crypto/cipher" "crypto/cipher"
"crypto/dsa" "crypto/dsa"
"crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
"crypto/sha256"
"crypto/subtle"
"fmt"
"io" "io"
"io/ioutil"
"math/big" "math/big"
"strconv" "strconv"
"time" "time"
"github.com/ProtonMail/go-crypto/openpgp/ecdh" "github.com/ProtonMail/go-crypto/openpgp/ecdh"
"github.com/ProtonMail/go-crypto/openpgp/ecdsa" "github.com/ProtonMail/go-crypto/openpgp/ecdsa"
"github.com/ProtonMail/go-crypto/openpgp/ed25519"
"github.com/ProtonMail/go-crypto/openpgp/ed448"
"github.com/ProtonMail/go-crypto/openpgp/eddsa" "github.com/ProtonMail/go-crypto/openpgp/eddsa"
"github.com/ProtonMail/go-crypto/openpgp/elgamal" "github.com/ProtonMail/go-crypto/openpgp/elgamal"
"github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/errors"
"github.com/ProtonMail/go-crypto/openpgp/internal/encoding" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding"
"github.com/ProtonMail/go-crypto/openpgp/s2k" "github.com/ProtonMail/go-crypto/openpgp/s2k"
"github.com/ProtonMail/go-crypto/openpgp/x25519"
"github.com/ProtonMail/go-crypto/openpgp/x448"
"golang.org/x/crypto/hkdf"
) )
// PrivateKey represents a possibly encrypted private key. See RFC 4880, // PrivateKey represents a possibly encrypted private key. See RFC 4880,
@@ -35,14 +41,14 @@ type PrivateKey struct {
encryptedData []byte encryptedData []byte
cipher CipherFunction cipher CipherFunction
s2k func(out, in []byte) s2k func(out, in []byte)
// An *{rsa|dsa|elgamal|ecdh|ecdsa|ed25519}.PrivateKey or aead AEADMode // only relevant if S2KAEAD is enabled
// An *{rsa|dsa|elgamal|ecdh|ecdsa|ed25519|ed448}.PrivateKey or
// crypto.Signer/crypto.Decrypter (Decryptor RSA only). // crypto.Signer/crypto.Decrypter (Decryptor RSA only).
PrivateKey interface{} PrivateKey interface{}
sha1Checksum bool iv []byte
iv []byte
// Type of encryption of the S2K packet // Type of encryption of the S2K packet
// Allowed values are 0 (Not encrypted), 254 (SHA1), or // Allowed values are 0 (Not encrypted), 253 (AEAD), 254 (SHA1), or
// 255 (2-byte checksum) // 255 (2-byte checksum)
s2kType S2KType s2kType S2KType
// Full parameters of the S2K packet // Full parameters of the S2K packet
@@ -55,6 +61,8 @@ type S2KType uint8
const ( const (
// S2KNON unencrypt // S2KNON unencrypt
S2KNON S2KType = 0 S2KNON S2KType = 0
// S2KAEAD use authenticated encryption
S2KAEAD S2KType = 253
// S2KSHA1 sha1 sum check // S2KSHA1 sha1 sum check
S2KSHA1 S2KType = 254 S2KSHA1 S2KType = 254
// S2KCHECKSUM sum check // S2KCHECKSUM sum check
@@ -103,6 +111,34 @@ func NewECDHPrivateKey(creationTime time.Time, priv *ecdh.PrivateKey) *PrivateKe
return pk return pk
} }
func NewX25519PrivateKey(creationTime time.Time, priv *x25519.PrivateKey) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewX25519PublicKey(creationTime, &priv.PublicKey)
pk.PrivateKey = priv
return pk
}
func NewX448PrivateKey(creationTime time.Time, priv *x448.PrivateKey) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewX448PublicKey(creationTime, &priv.PublicKey)
pk.PrivateKey = priv
return pk
}
func NewEd25519PrivateKey(creationTime time.Time, priv *ed25519.PrivateKey) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewEd25519PublicKey(creationTime, &priv.PublicKey)
pk.PrivateKey = priv
return pk
}
func NewEd448PrivateKey(creationTime time.Time, priv *ed448.PrivateKey) *PrivateKey {
pk := new(PrivateKey)
pk.PublicKey = *NewEd448PublicKey(creationTime, &priv.PublicKey)
pk.PrivateKey = priv
return pk
}
// NewSignerPrivateKey creates a PrivateKey from a crypto.Signer that // NewSignerPrivateKey creates a PrivateKey from a crypto.Signer that
// implements RSA, ECDSA or EdDSA. // implements RSA, ECDSA or EdDSA.
func NewSignerPrivateKey(creationTime time.Time, signer interface{}) *PrivateKey { func NewSignerPrivateKey(creationTime time.Time, signer interface{}) *PrivateKey {
@@ -122,6 +158,14 @@ func NewSignerPrivateKey(creationTime time.Time, signer interface{}) *PrivateKey
pk.PublicKey = *NewEdDSAPublicKey(creationTime, &pubkey.PublicKey) pk.PublicKey = *NewEdDSAPublicKey(creationTime, &pubkey.PublicKey)
case eddsa.PrivateKey: case eddsa.PrivateKey:
pk.PublicKey = *NewEdDSAPublicKey(creationTime, &pubkey.PublicKey) pk.PublicKey = *NewEdDSAPublicKey(creationTime, &pubkey.PublicKey)
case *ed25519.PrivateKey:
pk.PublicKey = *NewEd25519PublicKey(creationTime, &pubkey.PublicKey)
case ed25519.PrivateKey:
pk.PublicKey = *NewEd25519PublicKey(creationTime, &pubkey.PublicKey)
case *ed448.PrivateKey:
pk.PublicKey = *NewEd448PublicKey(creationTime, &pubkey.PublicKey)
case ed448.PrivateKey:
pk.PublicKey = *NewEd448PublicKey(creationTime, &pubkey.PublicKey)
default: default:
panic("openpgp: unknown signer type in NewSignerPrivateKey") panic("openpgp: unknown signer type in NewSignerPrivateKey")
} }
@@ -129,7 +173,7 @@ func NewSignerPrivateKey(creationTime time.Time, signer interface{}) *PrivateKey
return pk return pk
} }
// NewDecrypterPrivateKey creates a PrivateKey from a *{rsa|elgamal|ecdh}.PrivateKey. // NewDecrypterPrivateKey creates a PrivateKey from a *{rsa|elgamal|ecdh|x25519|x448}.PrivateKey.
func NewDecrypterPrivateKey(creationTime time.Time, decrypter interface{}) *PrivateKey { func NewDecrypterPrivateKey(creationTime time.Time, decrypter interface{}) *PrivateKey {
pk := new(PrivateKey) pk := new(PrivateKey)
switch priv := decrypter.(type) { switch priv := decrypter.(type) {
@@ -139,6 +183,10 @@ func NewDecrypterPrivateKey(creationTime time.Time, decrypter interface{}) *Priv
pk.PublicKey = *NewElGamalPublicKey(creationTime, &priv.PublicKey) pk.PublicKey = *NewElGamalPublicKey(creationTime, &priv.PublicKey)
case *ecdh.PrivateKey: case *ecdh.PrivateKey:
pk.PublicKey = *NewECDHPublicKey(creationTime, &priv.PublicKey) pk.PublicKey = *NewECDHPublicKey(creationTime, &priv.PublicKey)
case *x25519.PrivateKey:
pk.PublicKey = *NewX25519PublicKey(creationTime, &priv.PublicKey)
case *x448.PrivateKey:
pk.PublicKey = *NewX448PublicKey(creationTime, &priv.PublicKey)
default: default:
panic("openpgp: unknown decrypter type in NewDecrypterPrivateKey") panic("openpgp: unknown decrypter type in NewDecrypterPrivateKey")
} }
@@ -152,6 +200,11 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
return return
} }
v5 := pk.PublicKey.Version == 5 v5 := pk.PublicKey.Version == 5
v6 := pk.PublicKey.Version == 6
if V5Disabled && v5 {
return errors.UnsupportedError("support for parsing v5 entities is disabled; build with `-tags v5` if needed")
}
var buf [1]byte var buf [1]byte
_, err = readFull(r, buf[:]) _, err = readFull(r, buf[:])
@@ -160,7 +213,7 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
} }
pk.s2kType = S2KType(buf[0]) pk.s2kType = S2KType(buf[0])
var optCount [1]byte var optCount [1]byte
if v5 { if v5 || (v6 && pk.s2kType != S2KNON) {
if _, err = readFull(r, optCount[:]); err != nil { if _, err = readFull(r, optCount[:]); err != nil {
return return
} }
@@ -170,9 +223,9 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
case S2KNON: case S2KNON:
pk.s2k = nil pk.s2k = nil
pk.Encrypted = false pk.Encrypted = false
case S2KSHA1, S2KCHECKSUM: case S2KSHA1, S2KCHECKSUM, S2KAEAD:
if v5 && pk.s2kType == S2KCHECKSUM { if (v5 || v6) && pk.s2kType == S2KCHECKSUM {
return errors.StructuralError("wrong s2k identifier for version 5") return errors.StructuralError(fmt.Sprintf("wrong s2k identifier for version %d", pk.Version))
} }
_, err = readFull(r, buf[:]) _, err = readFull(r, buf[:])
if err != nil { if err != nil {
@@ -182,6 +235,29 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
if pk.cipher != 0 && !pk.cipher.IsSupported() { if pk.cipher != 0 && !pk.cipher.IsSupported() {
return errors.UnsupportedError("unsupported cipher function in private key") return errors.UnsupportedError("unsupported cipher function in private key")
} }
// [Optional] If string-to-key usage octet was 253,
// a one-octet AEAD algorithm.
if pk.s2kType == S2KAEAD {
_, err = readFull(r, buf[:])
if err != nil {
return
}
pk.aead = AEADMode(buf[0])
if !pk.aead.IsSupported() {
return errors.UnsupportedError("unsupported aead mode in private key")
}
}
// [Optional] Only for a version 6 packet,
// and if string-to-key usage octet was 255, 254, or 253,
// an one-octet count of the following field.
if v6 {
_, err = readFull(r, buf[:])
if err != nil {
return
}
}
pk.s2kParams, err = s2k.ParseIntoParams(r) pk.s2kParams, err = s2k.ParseIntoParams(r)
if err != nil { if err != nil {
return return
@@ -189,28 +265,43 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
if pk.s2kParams.Dummy() { if pk.s2kParams.Dummy() {
return return
} }
if pk.s2kParams.Mode() == s2k.Argon2S2K && pk.s2kType != S2KAEAD {
return errors.StructuralError("using Argon2 S2K without AEAD is not allowed")
}
if pk.s2kParams.Mode() == s2k.SimpleS2K && pk.Version == 6 {
return errors.StructuralError("using Simple S2K with version 6 keys is not allowed")
}
pk.s2k, err = pk.s2kParams.Function() pk.s2k, err = pk.s2kParams.Function()
if err != nil { if err != nil {
return return
} }
pk.Encrypted = true pk.Encrypted = true
if pk.s2kType == S2KSHA1 {
pk.sha1Checksum = true
}
default: default:
return errors.UnsupportedError("deprecated s2k function in private key") return errors.UnsupportedError("deprecated s2k function in private key")
} }
if pk.Encrypted { if pk.Encrypted {
blockSize := pk.cipher.blockSize() var ivSize int
if blockSize == 0 { // If the S2K usage octet was 253, the IV is of the size expected by the AEAD mode,
// unless it's a version 5 key, in which case it's the size of the symmetric cipher's block size.
// For all other S2K modes, it's always the block size.
if !v5 && pk.s2kType == S2KAEAD {
ivSize = pk.aead.IvLength()
} else {
ivSize = pk.cipher.blockSize()
}
if ivSize == 0 {
return errors.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher))) return errors.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher)))
} }
pk.iv = make([]byte, blockSize) pk.iv = make([]byte, ivSize)
_, err = readFull(r, pk.iv) _, err = readFull(r, pk.iv)
if err != nil { if err != nil {
return return
} }
if v5 && pk.s2kType == S2KAEAD {
pk.iv = pk.iv[:pk.aead.IvLength()]
}
} }
var privateKeyData []byte var privateKeyData []byte
@@ -230,7 +321,7 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
return return
} }
} else { } else {
privateKeyData, err = ioutil.ReadAll(r) privateKeyData, err = io.ReadAll(r)
if err != nil { if err != nil {
return return
} }
@@ -239,16 +330,22 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) {
if len(privateKeyData) < 2 { if len(privateKeyData) < 2 {
return errors.StructuralError("truncated private key data") return errors.StructuralError("truncated private key data")
} }
var sum uint16 if pk.Version != 6 {
for i := 0; i < len(privateKeyData)-2; i++ { // checksum
sum += uint16(privateKeyData[i]) var sum uint16
for i := 0; i < len(privateKeyData)-2; i++ {
sum += uint16(privateKeyData[i])
}
if privateKeyData[len(privateKeyData)-2] != uint8(sum>>8) ||
privateKeyData[len(privateKeyData)-1] != uint8(sum) {
return errors.StructuralError("private key checksum failure")
}
privateKeyData = privateKeyData[:len(privateKeyData)-2]
return pk.parsePrivateKey(privateKeyData)
} else {
// No checksum
return pk.parsePrivateKey(privateKeyData)
} }
if privateKeyData[len(privateKeyData)-2] != uint8(sum>>8) ||
privateKeyData[len(privateKeyData)-1] != uint8(sum) {
return errors.StructuralError("private key checksum failure")
}
privateKeyData = privateKeyData[:len(privateKeyData)-2]
return pk.parsePrivateKey(privateKeyData)
} }
pk.encryptedData = privateKeyData pk.encryptedData = privateKeyData
@@ -280,18 +377,59 @@ func (pk *PrivateKey) Serialize(w io.Writer) (err error) {
optional := bytes.NewBuffer(nil) optional := bytes.NewBuffer(nil)
if pk.Encrypted || pk.Dummy() { if pk.Encrypted || pk.Dummy() {
optional.Write([]byte{uint8(pk.cipher)}) // [Optional] If string-to-key usage octet was 255, 254, or 253,
if err := pk.s2kParams.Serialize(optional); err != nil { // a one-octet symmetric encryption algorithm.
if _, err = optional.Write([]byte{uint8(pk.cipher)}); err != nil {
return
}
// [Optional] If string-to-key usage octet was 253,
// a one-octet AEAD algorithm.
if pk.s2kType == S2KAEAD {
if _, err = optional.Write([]byte{uint8(pk.aead)}); err != nil {
return
}
}
s2kBuffer := bytes.NewBuffer(nil)
if err := pk.s2kParams.Serialize(s2kBuffer); err != nil {
return err return err
} }
// [Optional] Only for a version 6 packet, and if string-to-key
// usage octet was 255, 254, or 253, an one-octet
// count of the following field.
if pk.Version == 6 {
if _, err = optional.Write([]byte{uint8(s2kBuffer.Len())}); err != nil {
return
}
}
// [Optional] If string-to-key usage octet was 255, 254, or 253,
// a string-to-key (S2K) specifier. The length of the string-to-key specifier
// depends on its type
if _, err = io.Copy(optional, s2kBuffer); err != nil {
return
}
// IV
if pk.Encrypted { if pk.Encrypted {
optional.Write(pk.iv) if _, err = optional.Write(pk.iv); err != nil {
return
}
if pk.Version == 5 && pk.s2kType == S2KAEAD {
// Add padding for version 5
padding := make([]byte, pk.cipher.blockSize()-len(pk.iv))
if _, err = optional.Write(padding); err != nil {
return
}
}
} }
} }
if pk.Version == 5 { if pk.Version == 5 || (pk.Version == 6 && pk.s2kType != S2KNON) {
contents.Write([]byte{uint8(optional.Len())}) contents.Write([]byte{uint8(optional.Len())})
} }
io.Copy(contents, optional)
if _, err := io.Copy(contents, optional); err != nil {
return err
}
if !pk.Dummy() { if !pk.Dummy() {
l := 0 l := 0
@@ -303,8 +441,10 @@ func (pk *PrivateKey) Serialize(w io.Writer) (err error) {
return err return err
} }
l = buf.Len() l = buf.Len()
checksum := mod64kHash(buf.Bytes()) if pk.Version != 6 {
buf.Write([]byte{byte(checksum >> 8), byte(checksum)}) checksum := mod64kHash(buf.Bytes())
buf.Write([]byte{byte(checksum >> 8), byte(checksum)})
}
priv = buf.Bytes() priv = buf.Bytes()
} else { } else {
priv, l = pk.encryptedData, len(pk.encryptedData) priv, l = pk.encryptedData, len(pk.encryptedData)
@@ -370,6 +510,26 @@ func serializeECDHPrivateKey(w io.Writer, priv *ecdh.PrivateKey) error {
return err return err
} }
func serializeX25519PrivateKey(w io.Writer, priv *x25519.PrivateKey) error {
_, err := w.Write(priv.Secret)
return err
}
func serializeX448PrivateKey(w io.Writer, priv *x448.PrivateKey) error {
_, err := w.Write(priv.Secret)
return err
}
func serializeEd25519PrivateKey(w io.Writer, priv *ed25519.PrivateKey) error {
_, err := w.Write(priv.MarshalByteSecret())
return err
}
func serializeEd448PrivateKey(w io.Writer, priv *ed448.PrivateKey) error {
_, err := w.Write(priv.MarshalByteSecret())
return err
}
// decrypt decrypts an encrypted private key using a decryption key. // decrypt decrypts an encrypted private key using a decryption key.
func (pk *PrivateKey) decrypt(decryptionKey []byte) error { func (pk *PrivateKey) decrypt(decryptionKey []byte) error {
if pk.Dummy() { if pk.Dummy() {
@@ -378,37 +538,51 @@ func (pk *PrivateKey) decrypt(decryptionKey []byte) error {
if !pk.Encrypted { if !pk.Encrypted {
return nil return nil
} }
block := pk.cipher.new(decryptionKey) block := pk.cipher.new(decryptionKey)
cfb := cipher.NewCFBDecrypter(block, pk.iv) var data []byte
switch pk.s2kType {
data := make([]byte, len(pk.encryptedData)) case S2KAEAD:
cfb.XORKeyStream(data, pk.encryptedData) aead := pk.aead.new(block)
additionalData, err := pk.additionalData()
if pk.sha1Checksum { if err != nil {
if len(data) < sha1.Size { return err
return errors.StructuralError("truncated private key data")
} }
h := sha1.New() // Decrypt the encrypted key material with aead
h.Write(data[:len(data)-sha1.Size]) data, err = aead.Open(nil, pk.iv, pk.encryptedData, additionalData)
sum := h.Sum(nil) if err != nil {
if !bytes.Equal(sum, data[len(data)-sha1.Size:]) { return err
return errors.StructuralError("private key checksum failure")
} }
data = data[:len(data)-sha1.Size] case S2KSHA1, S2KCHECKSUM:
} else { cfb := cipher.NewCFBDecrypter(block, pk.iv)
if len(data) < 2 { data = make([]byte, len(pk.encryptedData))
return errors.StructuralError("truncated private key data") cfb.XORKeyStream(data, pk.encryptedData)
if pk.s2kType == S2KSHA1 {
if len(data) < sha1.Size {
return errors.StructuralError("truncated private key data")
}
h := sha1.New()
h.Write(data[:len(data)-sha1.Size])
sum := h.Sum(nil)
if !bytes.Equal(sum, data[len(data)-sha1.Size:]) {
return errors.StructuralError("private key checksum failure")
}
data = data[:len(data)-sha1.Size]
} else {
if len(data) < 2 {
return errors.StructuralError("truncated private key data")
}
var sum uint16
for i := 0; i < len(data)-2; i++ {
sum += uint16(data[i])
}
if data[len(data)-2] != uint8(sum>>8) ||
data[len(data)-1] != uint8(sum) {
return errors.StructuralError("private key checksum failure")
}
data = data[:len(data)-2]
} }
var sum uint16 default:
for i := 0; i < len(data)-2; i++ { return errors.InvalidArgumentError("invalid s2k type")
sum += uint16(data[i])
}
if data[len(data)-2] != uint8(sum>>8) ||
data[len(data)-1] != uint8(sum) {
return errors.StructuralError("private key checksum failure")
}
data = data[:len(data)-2]
} }
err := pk.parsePrivateKey(data) err := pk.parsePrivateKey(data)
@@ -424,7 +598,6 @@ func (pk *PrivateKey) decrypt(decryptionKey []byte) error {
pk.s2k = nil pk.s2k = nil
pk.Encrypted = false pk.Encrypted = false
pk.encryptedData = nil pk.encryptedData = nil
return nil return nil
} }
@@ -440,6 +613,9 @@ func (pk *PrivateKey) decryptWithCache(passphrase []byte, keyCache *s2k.Cache) e
if err != nil { if err != nil {
return err return err
} }
if pk.s2kType == S2KAEAD {
key = pk.applyHKDF(key)
}
return pk.decrypt(key) return pk.decrypt(key)
} }
@@ -454,6 +630,9 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error {
key := make([]byte, pk.cipher.KeySize()) key := make([]byte, pk.cipher.KeySize())
pk.s2k(key, passphrase) pk.s2k(key, passphrase)
if pk.s2kType == S2KAEAD {
key = pk.applyHKDF(key)
}
return pk.decrypt(key) return pk.decrypt(key)
} }
@@ -474,7 +653,7 @@ func DecryptPrivateKeys(keys []*PrivateKey, passphrase []byte) error {
} }
// encrypt encrypts an unencrypted private key. // encrypt encrypts an unencrypted private key.
func (pk *PrivateKey) encrypt(key []byte, params *s2k.Params, cipherFunction CipherFunction) error { func (pk *PrivateKey) encrypt(key []byte, params *s2k.Params, s2kType S2KType, cipherFunction CipherFunction, rand io.Reader) error {
if pk.Dummy() { if pk.Dummy() {
return errors.ErrDummyPrivateKey("dummy key found") return errors.ErrDummyPrivateKey("dummy key found")
} }
@@ -486,6 +665,14 @@ func (pk *PrivateKey) encrypt(key []byte, params *s2k.Params, cipherFunction Cip
return errors.InvalidArgumentError("supplied encryption key has the wrong size") return errors.InvalidArgumentError("supplied encryption key has the wrong size")
} }
if params.Mode() == s2k.Argon2S2K && s2kType != S2KAEAD {
return errors.InvalidArgumentError("using Argon2 S2K without AEAD is not allowed")
}
if params.Mode() != s2k.Argon2S2K && params.Mode() != s2k.IteratedSaltedS2K &&
params.Mode() != s2k.SaltedS2K { // only allowed for high-entropy passphrases
return errors.InvalidArgumentError("insecure S2K mode")
}
priv := bytes.NewBuffer(nil) priv := bytes.NewBuffer(nil)
err := pk.serializePrivateKey(priv) err := pk.serializePrivateKey(priv)
if err != nil { if err != nil {
@@ -500,32 +687,50 @@ func (pk *PrivateKey) encrypt(key []byte, params *s2k.Params, cipherFunction Cip
} }
privateKeyBytes := priv.Bytes() privateKeyBytes := priv.Bytes()
pk.sha1Checksum = true pk.s2kType = s2kType
block := pk.cipher.new(key) block := pk.cipher.new(key)
pk.iv = make([]byte, pk.cipher.blockSize()) switch s2kType {
_, err = rand.Read(pk.iv) case S2KAEAD:
if err != nil { if pk.aead == 0 {
return err return errors.StructuralError("aead mode is not set on key")
}
cfb := cipher.NewCFBEncrypter(block, pk.iv)
if pk.sha1Checksum {
pk.s2kType = S2KSHA1
h := sha1.New()
h.Write(privateKeyBytes)
sum := h.Sum(nil)
privateKeyBytes = append(privateKeyBytes, sum...)
} else {
pk.s2kType = S2KCHECKSUM
var sum uint16
for _, b := range privateKeyBytes {
sum += uint16(b)
} }
priv.Write([]byte{uint8(sum >> 8), uint8(sum)}) aead := pk.aead.new(block)
additionalData, err := pk.additionalData()
if err != nil {
return err
}
pk.iv = make([]byte, aead.NonceSize())
_, err = io.ReadFull(rand, pk.iv)
if err != nil {
return err
}
// Decrypt the encrypted key material with aead
pk.encryptedData = aead.Seal(nil, pk.iv, privateKeyBytes, additionalData)
case S2KSHA1, S2KCHECKSUM:
pk.iv = make([]byte, pk.cipher.blockSize())
_, err = io.ReadFull(rand, pk.iv)
if err != nil {
return err
}
cfb := cipher.NewCFBEncrypter(block, pk.iv)
if s2kType == S2KSHA1 {
h := sha1.New()
h.Write(privateKeyBytes)
sum := h.Sum(nil)
privateKeyBytes = append(privateKeyBytes, sum...)
} else {
var sum uint16
for _, b := range privateKeyBytes {
sum += uint16(b)
}
privateKeyBytes = append(privateKeyBytes, []byte{uint8(sum >> 8), uint8(sum)}...)
}
pk.encryptedData = make([]byte, len(privateKeyBytes))
cfb.XORKeyStream(pk.encryptedData, privateKeyBytes)
default:
return errors.InvalidArgumentError("invalid s2k type for encryption")
} }
pk.encryptedData = make([]byte, len(privateKeyBytes))
cfb.XORKeyStream(pk.encryptedData, privateKeyBytes)
pk.Encrypted = true pk.Encrypted = true
pk.PrivateKey = nil pk.PrivateKey = nil
return err return err
@@ -544,8 +749,15 @@ func (pk *PrivateKey) EncryptWithConfig(passphrase []byte, config *Config) error
return err return err
} }
s2k(key, passphrase) s2k(key, passphrase)
s2kType := S2KSHA1
if config.AEAD() != nil {
s2kType = S2KAEAD
pk.aead = config.AEAD().Mode()
pk.cipher = config.Cipher()
key = pk.applyHKDF(key)
}
// Encrypt the private key with the derived encryption key. // Encrypt the private key with the derived encryption key.
return pk.encrypt(key, params, config.Cipher()) return pk.encrypt(key, params, s2kType, config.Cipher(), config.Random())
} }
// EncryptPrivateKeys encrypts all unencrypted keys with the given config and passphrase. // EncryptPrivateKeys encrypts all unencrypted keys with the given config and passphrase.
@@ -564,7 +776,16 @@ func EncryptPrivateKeys(keys []*PrivateKey, passphrase []byte, config *Config) e
s2k(encryptionKey, passphrase) s2k(encryptionKey, passphrase)
for _, key := range keys { for _, key := range keys {
if key != nil && !key.Dummy() && !key.Encrypted { if key != nil && !key.Dummy() && !key.Encrypted {
err = key.encrypt(encryptionKey, params, config.Cipher()) s2kType := S2KSHA1
if config.AEAD() != nil {
s2kType = S2KAEAD
key.aead = config.AEAD().Mode()
key.cipher = config.Cipher()
derivedKey := key.applyHKDF(encryptionKey)
err = key.encrypt(derivedKey, params, s2kType, config.Cipher(), config.Random())
} else {
err = key.encrypt(encryptionKey, params, s2kType, config.Cipher(), config.Random())
}
if err != nil { if err != nil {
return err return err
} }
@@ -581,7 +802,7 @@ func (pk *PrivateKey) Encrypt(passphrase []byte) error {
S2KMode: s2k.IteratedSaltedS2K, S2KMode: s2k.IteratedSaltedS2K,
S2KCount: 65536, S2KCount: 65536,
Hash: crypto.SHA256, Hash: crypto.SHA256,
} , },
DefaultCipher: CipherAES256, DefaultCipher: CipherAES256,
} }
return pk.EncryptWithConfig(passphrase, config) return pk.EncryptWithConfig(passphrase, config)
@@ -601,6 +822,14 @@ func (pk *PrivateKey) serializePrivateKey(w io.Writer) (err error) {
err = serializeEdDSAPrivateKey(w, priv) err = serializeEdDSAPrivateKey(w, priv)
case *ecdh.PrivateKey: case *ecdh.PrivateKey:
err = serializeECDHPrivateKey(w, priv) err = serializeECDHPrivateKey(w, priv)
case *x25519.PrivateKey:
err = serializeX25519PrivateKey(w, priv)
case *x448.PrivateKey:
err = serializeX448PrivateKey(w, priv)
case *ed25519.PrivateKey:
err = serializeEd25519PrivateKey(w, priv)
case *ed448.PrivateKey:
err = serializeEd448PrivateKey(w, priv)
default: default:
err = errors.InvalidArgumentError("unknown private key type") err = errors.InvalidArgumentError("unknown private key type")
} }
@@ -621,8 +850,18 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err error) {
return pk.parseECDHPrivateKey(data) return pk.parseECDHPrivateKey(data)
case PubKeyAlgoEdDSA: case PubKeyAlgoEdDSA:
return pk.parseEdDSAPrivateKey(data) return pk.parseEdDSAPrivateKey(data)
case PubKeyAlgoX25519:
return pk.parseX25519PrivateKey(data)
case PubKeyAlgoX448:
return pk.parseX448PrivateKey(data)
case PubKeyAlgoEd25519:
return pk.parseEd25519PrivateKey(data)
case PubKeyAlgoEd448:
return pk.parseEd448PrivateKey(data)
default:
err = errors.StructuralError("unknown private key type")
return
} }
panic("impossible")
} }
func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err error) { func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err error) {
@@ -743,6 +982,86 @@ func (pk *PrivateKey) parseECDHPrivateKey(data []byte) (err error) {
return nil return nil
} }
func (pk *PrivateKey) parseX25519PrivateKey(data []byte) (err error) {
publicKey := pk.PublicKey.PublicKey.(*x25519.PublicKey)
privateKey := x25519.NewPrivateKey(*publicKey)
privateKey.PublicKey = *publicKey
privateKey.Secret = make([]byte, x25519.KeySize)
if len(data) != x25519.KeySize {
err = errors.StructuralError("wrong x25519 key size")
return err
}
subtle.ConstantTimeCopy(1, privateKey.Secret, data)
if err = x25519.Validate(privateKey); err != nil {
return err
}
pk.PrivateKey = privateKey
return nil
}
func (pk *PrivateKey) parseX448PrivateKey(data []byte) (err error) {
publicKey := pk.PublicKey.PublicKey.(*x448.PublicKey)
privateKey := x448.NewPrivateKey(*publicKey)
privateKey.PublicKey = *publicKey
privateKey.Secret = make([]byte, x448.KeySize)
if len(data) != x448.KeySize {
err = errors.StructuralError("wrong x448 key size")
return err
}
subtle.ConstantTimeCopy(1, privateKey.Secret, data)
if err = x448.Validate(privateKey); err != nil {
return err
}
pk.PrivateKey = privateKey
return nil
}
func (pk *PrivateKey) parseEd25519PrivateKey(data []byte) (err error) {
publicKey := pk.PublicKey.PublicKey.(*ed25519.PublicKey)
privateKey := ed25519.NewPrivateKey(*publicKey)
privateKey.PublicKey = *publicKey
if len(data) != ed25519.SeedSize {
err = errors.StructuralError("wrong ed25519 key size")
return err
}
err = privateKey.UnmarshalByteSecret(data)
if err != nil {
return err
}
err = ed25519.Validate(privateKey)
if err != nil {
return err
}
pk.PrivateKey = privateKey
return nil
}
func (pk *PrivateKey) parseEd448PrivateKey(data []byte) (err error) {
publicKey := pk.PublicKey.PublicKey.(*ed448.PublicKey)
privateKey := ed448.NewPrivateKey(*publicKey)
privateKey.PublicKey = *publicKey
if len(data) != ed448.SeedSize {
err = errors.StructuralError("wrong ed448 key size")
return err
}
err = privateKey.UnmarshalByteSecret(data)
if err != nil {
return err
}
err = ed448.Validate(privateKey)
if err != nil {
return err
}
pk.PrivateKey = privateKey
return nil
}
func (pk *PrivateKey) parseEdDSAPrivateKey(data []byte) (err error) { func (pk *PrivateKey) parseEdDSAPrivateKey(data []byte) (err error) {
eddsaPub := pk.PublicKey.PublicKey.(*eddsa.PublicKey) eddsaPub := pk.PublicKey.PublicKey.(*eddsa.PublicKey)
eddsaPriv := eddsa.NewPrivateKey(*eddsaPub) eddsaPriv := eddsa.NewPrivateKey(*eddsaPub)
@@ -767,6 +1086,41 @@ func (pk *PrivateKey) parseEdDSAPrivateKey(data []byte) (err error) {
return nil return nil
} }
func (pk *PrivateKey) additionalData() ([]byte, error) {
additionalData := bytes.NewBuffer(nil)
// Write additional data prefix based on packet type
var packetByte byte
if pk.PublicKey.IsSubkey {
packetByte = 0xc7
} else {
packetByte = 0xc5
}
// Write public key to additional data
_, err := additionalData.Write([]byte{packetByte})
if err != nil {
return nil, err
}
err = pk.PublicKey.serializeWithoutHeaders(additionalData)
if err != nil {
return nil, err
}
return additionalData.Bytes(), nil
}
func (pk *PrivateKey) applyHKDF(inputKey []byte) []byte {
var packetByte byte
if pk.PublicKey.IsSubkey {
packetByte = 0xc7
} else {
packetByte = 0xc5
}
associatedData := []byte{packetByte, byte(pk.Version), byte(pk.cipher), byte(pk.aead)}
hkdfReader := hkdf.New(sha256.New, inputKey, []byte{}, associatedData)
encryptionKey := make([]byte, pk.cipher.KeySize())
_, _ = readFull(hkdfReader, encryptionKey)
return encryptionKey
}
func validateDSAParameters(priv *dsa.PrivateKey) error { func validateDSAParameters(priv *dsa.PrivateKey) error {
p := priv.P // group prime p := priv.P // group prime
q := priv.Q // subgroup order q := priv.Q // subgroup order

View File

@@ -5,7 +5,6 @@
package packet package packet
import ( import (
"crypto"
"crypto/dsa" "crypto/dsa"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
@@ -21,23 +20,24 @@ import (
"github.com/ProtonMail/go-crypto/openpgp/ecdh" "github.com/ProtonMail/go-crypto/openpgp/ecdh"
"github.com/ProtonMail/go-crypto/openpgp/ecdsa" "github.com/ProtonMail/go-crypto/openpgp/ecdsa"
"github.com/ProtonMail/go-crypto/openpgp/ed25519"
"github.com/ProtonMail/go-crypto/openpgp/ed448"
"github.com/ProtonMail/go-crypto/openpgp/eddsa" "github.com/ProtonMail/go-crypto/openpgp/eddsa"
"github.com/ProtonMail/go-crypto/openpgp/elgamal" "github.com/ProtonMail/go-crypto/openpgp/elgamal"
"github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/errors"
"github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm"
"github.com/ProtonMail/go-crypto/openpgp/internal/ecc" "github.com/ProtonMail/go-crypto/openpgp/internal/ecc"
"github.com/ProtonMail/go-crypto/openpgp/internal/encoding" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding"
"github.com/ProtonMail/go-crypto/openpgp/x25519"
"github.com/ProtonMail/go-crypto/openpgp/x448"
) )
type kdfHashFunction byte
type kdfAlgorithm byte
// PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2. // PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2.
type PublicKey struct { type PublicKey struct {
Version int Version int
CreationTime time.Time CreationTime time.Time
PubKeyAlgo PublicKeyAlgorithm PubKeyAlgo PublicKeyAlgorithm
PublicKey interface{} // *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey or *eddsa.PublicKey PublicKey interface{} // *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey or *eddsa.PublicKey, *x25519.PublicKey, *x448.PublicKey, *ed25519.PublicKey, *ed448.PublicKey
Fingerprint []byte Fingerprint []byte
KeyId uint64 KeyId uint64
IsSubkey bool IsSubkey bool
@@ -61,11 +61,19 @@ func (pk *PublicKey) UpgradeToV5() {
pk.setFingerprintAndKeyId() pk.setFingerprintAndKeyId()
} }
// UpgradeToV6 updates the version of the key to v6, and updates all necessary
// fields.
func (pk *PublicKey) UpgradeToV6() error {
pk.Version = 6
pk.setFingerprintAndKeyId()
return pk.checkV6Compatibility()
}
// signingKey provides a convenient abstraction over signature verification // signingKey provides a convenient abstraction over signature verification
// for v3 and v4 public keys. // for v3 and v4 public keys.
type signingKey interface { type signingKey interface {
SerializeForHash(io.Writer) error SerializeForHash(io.Writer) error
SerializeSignaturePrefix(io.Writer) SerializeSignaturePrefix(io.Writer) error
serializeWithoutHeaders(io.Writer) error serializeWithoutHeaders(io.Writer) error
} }
@@ -174,6 +182,54 @@ func NewEdDSAPublicKey(creationTime time.Time, pub *eddsa.PublicKey) *PublicKey
return pk return pk
} }
func NewX25519PublicKey(creationTime time.Time, pub *x25519.PublicKey) *PublicKey {
pk := &PublicKey{
Version: 4,
CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoX25519,
PublicKey: pub,
}
pk.setFingerprintAndKeyId()
return pk
}
func NewX448PublicKey(creationTime time.Time, pub *x448.PublicKey) *PublicKey {
pk := &PublicKey{
Version: 4,
CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoX448,
PublicKey: pub,
}
pk.setFingerprintAndKeyId()
return pk
}
func NewEd25519PublicKey(creationTime time.Time, pub *ed25519.PublicKey) *PublicKey {
pk := &PublicKey{
Version: 4,
CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoEd25519,
PublicKey: pub,
}
pk.setFingerprintAndKeyId()
return pk
}
func NewEd448PublicKey(creationTime time.Time, pub *ed448.PublicKey) *PublicKey {
pk := &PublicKey{
Version: 4,
CreationTime: creationTime,
PubKeyAlgo: PubKeyAlgoEd448,
PublicKey: pub,
}
pk.setFingerprintAndKeyId()
return pk
}
func (pk *PublicKey) parse(r io.Reader) (err error) { func (pk *PublicKey) parse(r io.Reader) (err error) {
// RFC 4880, section 5.5.2 // RFC 4880, section 5.5.2
var buf [6]byte var buf [6]byte
@@ -181,12 +237,19 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
if err != nil { if err != nil {
return return
} }
if buf[0] != 4 && buf[0] != 5 {
pk.Version = int(buf[0])
if pk.Version != 4 && pk.Version != 5 && pk.Version != 6 {
return errors.UnsupportedError("public key version " + strconv.Itoa(int(buf[0]))) return errors.UnsupportedError("public key version " + strconv.Itoa(int(buf[0])))
} }
pk.Version = int(buf[0]) if V5Disabled && pk.Version == 5 {
if pk.Version == 5 { return errors.UnsupportedError("support for parsing v5 entities is disabled; build with `-tags v5` if needed")
}
if pk.Version >= 5 {
// Read the four-octet scalar octet count
// The count is not used in this implementation
var n [4]byte var n [4]byte
_, err = readFull(r, n[:]) _, err = readFull(r, n[:])
if err != nil { if err != nil {
@@ -195,6 +258,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
} }
pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0) pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0)
pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5]) pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5])
// Ignore four-ocet length
switch pk.PubKeyAlgo { switch pk.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly: case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
err = pk.parseRSA(r) err = pk.parseRSA(r)
@@ -208,6 +272,14 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
err = pk.parseECDH(r) err = pk.parseECDH(r)
case PubKeyAlgoEdDSA: case PubKeyAlgoEdDSA:
err = pk.parseEdDSA(r) err = pk.parseEdDSA(r)
case PubKeyAlgoX25519:
err = pk.parseX25519(r)
case PubKeyAlgoX448:
err = pk.parseX448(r)
case PubKeyAlgoEd25519:
err = pk.parseEd25519(r)
case PubKeyAlgoEd448:
err = pk.parseEd448(r)
default: default:
err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo))) err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo)))
} }
@@ -221,21 +293,44 @@ func (pk *PublicKey) parse(r io.Reader) (err error) {
func (pk *PublicKey) setFingerprintAndKeyId() { func (pk *PublicKey) setFingerprintAndKeyId() {
// RFC 4880, section 12.2 // RFC 4880, section 12.2
if pk.Version == 5 { if pk.Version >= 5 {
fingerprint := sha256.New() fingerprint := sha256.New()
pk.SerializeForHash(fingerprint) if err := pk.SerializeForHash(fingerprint); err != nil {
// Should not happen for a hash.
panic(err)
}
pk.Fingerprint = make([]byte, 32) pk.Fingerprint = make([]byte, 32)
copy(pk.Fingerprint, fingerprint.Sum(nil)) copy(pk.Fingerprint, fingerprint.Sum(nil))
pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[:8]) pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[:8])
} else { } else {
fingerprint := sha1.New() fingerprint := sha1.New()
pk.SerializeForHash(fingerprint) if err := pk.SerializeForHash(fingerprint); err != nil {
// Should not happen for a hash.
panic(err)
}
pk.Fingerprint = make([]byte, 20) pk.Fingerprint = make([]byte, 20)
copy(pk.Fingerprint, fingerprint.Sum(nil)) copy(pk.Fingerprint, fingerprint.Sum(nil))
pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[12:20]) pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[12:20])
} }
} }
func (pk *PublicKey) checkV6Compatibility() error {
// Implementations MUST NOT accept or generate version 6 key material using the deprecated OIDs.
switch pk.PubKeyAlgo {
case PubKeyAlgoECDH:
curveInfo := ecc.FindByOid(pk.oid)
if curveInfo == nil {
return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid))
}
if curveInfo.GenName == ecc.Curve25519GenName {
return errors.StructuralError("cannot generate v6 key with deprecated OID: Curve25519Legacy")
}
case PubKeyAlgoEdDSA:
return errors.StructuralError("cannot generate v6 key with deprecated algorithm: EdDSALegacy")
}
return nil
}
// parseRSA parses RSA public key material from the given Reader. See RFC 4880, // parseRSA parses RSA public key material from the given Reader. See RFC 4880,
// section 5.5.2. // section 5.5.2.
func (pk *PublicKey) parseRSA(r io.Reader) (err error) { func (pk *PublicKey) parseRSA(r io.Reader) (err error) {
@@ -324,16 +419,17 @@ func (pk *PublicKey) parseECDSA(r io.Reader) (err error) {
if _, err = pk.oid.ReadFrom(r); err != nil { if _, err = pk.oid.ReadFrom(r); err != nil {
return return
} }
pk.p = new(encoding.MPI)
if _, err = pk.p.ReadFrom(r); err != nil {
return
}
curveInfo := ecc.FindByOid(pk.oid) curveInfo := ecc.FindByOid(pk.oid)
if curveInfo == nil { if curveInfo == nil {
return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid)) return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid))
} }
pk.p = new(encoding.MPI)
if _, err = pk.p.ReadFrom(r); err != nil {
return
}
c, ok := curveInfo.Curve.(ecc.ECDSACurve) c, ok := curveInfo.Curve.(ecc.ECDSACurve)
if !ok { if !ok {
return errors.UnsupportedError(fmt.Sprintf("unsupported oid: %x", pk.oid)) return errors.UnsupportedError(fmt.Sprintf("unsupported oid: %x", pk.oid))
@@ -353,6 +449,17 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) {
if _, err = pk.oid.ReadFrom(r); err != nil { if _, err = pk.oid.ReadFrom(r); err != nil {
return return
} }
curveInfo := ecc.FindByOid(pk.oid)
if curveInfo == nil {
return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid))
}
if pk.Version == 6 && curveInfo.GenName == ecc.Curve25519GenName {
// Implementations MUST NOT accept or generate version 6 key material using the deprecated OIDs.
return errors.StructuralError("cannot read v6 key with deprecated OID: Curve25519Legacy")
}
pk.p = new(encoding.MPI) pk.p = new(encoding.MPI)
if _, err = pk.p.ReadFrom(r); err != nil { if _, err = pk.p.ReadFrom(r); err != nil {
return return
@@ -362,12 +469,6 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) {
return return
} }
curveInfo := ecc.FindByOid(pk.oid)
if curveInfo == nil {
return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid))
}
c, ok := curveInfo.Curve.(ecc.ECDHCurve) c, ok := curveInfo.Curve.(ecc.ECDHCurve)
if !ok { if !ok {
return errors.UnsupportedError(fmt.Sprintf("unsupported oid: %x", pk.oid)) return errors.UnsupportedError(fmt.Sprintf("unsupported oid: %x", pk.oid))
@@ -396,10 +497,16 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) {
} }
func (pk *PublicKey) parseEdDSA(r io.Reader) (err error) { func (pk *PublicKey) parseEdDSA(r io.Reader) (err error) {
if pk.Version == 6 {
// Implementations MUST NOT accept or generate version 6 key material using the deprecated OIDs.
return errors.StructuralError("cannot generate v6 key with deprecated algorithm: EdDSALegacy")
}
pk.oid = new(encoding.OID) pk.oid = new(encoding.OID)
if _, err = pk.oid.ReadFrom(r); err != nil { if _, err = pk.oid.ReadFrom(r); err != nil {
return return
} }
curveInfo := ecc.FindByOid(pk.oid) curveInfo := ecc.FindByOid(pk.oid)
if curveInfo == nil { if curveInfo == nil {
return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid)) return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid))
@@ -435,75 +542,145 @@ func (pk *PublicKey) parseEdDSA(r io.Reader) (err error) {
return return
} }
func (pk *PublicKey) parseX25519(r io.Reader) (err error) {
point := make([]byte, x25519.KeySize)
_, err = io.ReadFull(r, point)
if err != nil {
return
}
pub := &x25519.PublicKey{
Point: point,
}
pk.PublicKey = pub
return
}
func (pk *PublicKey) parseX448(r io.Reader) (err error) {
point := make([]byte, x448.KeySize)
_, err = io.ReadFull(r, point)
if err != nil {
return
}
pub := &x448.PublicKey{
Point: point,
}
pk.PublicKey = pub
return
}
func (pk *PublicKey) parseEd25519(r io.Reader) (err error) {
point := make([]byte, ed25519.PublicKeySize)
_, err = io.ReadFull(r, point)
if err != nil {
return
}
pub := &ed25519.PublicKey{
Point: point,
}
pk.PublicKey = pub
return
}
func (pk *PublicKey) parseEd448(r io.Reader) (err error) {
point := make([]byte, ed448.PublicKeySize)
_, err = io.ReadFull(r, point)
if err != nil {
return
}
pub := &ed448.PublicKey{
Point: point,
}
pk.PublicKey = pub
return
}
// SerializeForHash serializes the PublicKey to w with the special packet // SerializeForHash serializes the PublicKey to w with the special packet
// header format needed for hashing. // header format needed for hashing.
func (pk *PublicKey) SerializeForHash(w io.Writer) error { func (pk *PublicKey) SerializeForHash(w io.Writer) error {
pk.SerializeSignaturePrefix(w) if err := pk.SerializeSignaturePrefix(w); err != nil {
return err
}
return pk.serializeWithoutHeaders(w) return pk.serializeWithoutHeaders(w)
} }
// SerializeSignaturePrefix writes the prefix for this public key to the given Writer. // SerializeSignaturePrefix writes the prefix for this public key to the given Writer.
// The prefix is used when calculating a signature over this public key. See // The prefix is used when calculating a signature over this public key. See
// RFC 4880, section 5.2.4. // RFC 4880, section 5.2.4.
func (pk *PublicKey) SerializeSignaturePrefix(w io.Writer) { func (pk *PublicKey) SerializeSignaturePrefix(w io.Writer) error {
var pLength = pk.algorithmSpecificByteCount() var pLength = pk.algorithmSpecificByteCount()
if pk.Version == 5 { // version, timestamp, algorithm
pLength += 10 // version, timestamp (4), algorithm, key octet count (4). pLength += versionSize + timestampSize + algorithmSize
w.Write([]byte{ if pk.Version >= 5 {
0x9A, // key octet count (4).
pLength += 4
_, err := w.Write([]byte{
// When a v4 signature is made over a key, the hash data starts with the octet 0x99, followed by a two-octet length
// of the key, and then the body of the key packet. When a v6 signature is made over a key, the hash data starts
// with the salt, then octet 0x9B, followed by a four-octet length of the key, and then the body of the key packet.
0x95 + byte(pk.Version),
byte(pLength >> 24), byte(pLength >> 24),
byte(pLength >> 16), byte(pLength >> 16),
byte(pLength >> 8), byte(pLength >> 8),
byte(pLength), byte(pLength),
}) })
return return err
} }
pLength += 6 if _, err := w.Write([]byte{0x99, byte(pLength >> 8), byte(pLength)}); err != nil {
w.Write([]byte{0x99, byte(pLength >> 8), byte(pLength)}) return err
}
return nil
} }
func (pk *PublicKey) Serialize(w io.Writer) (err error) { func (pk *PublicKey) Serialize(w io.Writer) (err error) {
length := 6 // 6 byte header length := uint32(versionSize + timestampSize + algorithmSize) // 6 byte header
length += pk.algorithmSpecificByteCount() length += pk.algorithmSpecificByteCount()
if pk.Version == 5 { if pk.Version >= 5 {
length += 4 // octet key count length += 4 // octet key count
} }
packetType := packetTypePublicKey packetType := packetTypePublicKey
if pk.IsSubkey { if pk.IsSubkey {
packetType = packetTypePublicSubkey packetType = packetTypePublicSubkey
} }
err = serializeHeader(w, packetType, length) err = serializeHeader(w, packetType, int(length))
if err != nil { if err != nil {
return return
} }
return pk.serializeWithoutHeaders(w) return pk.serializeWithoutHeaders(w)
} }
func (pk *PublicKey) algorithmSpecificByteCount() int { func (pk *PublicKey) algorithmSpecificByteCount() uint32 {
length := 0 length := uint32(0)
switch pk.PubKeyAlgo { switch pk.PubKeyAlgo {
case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly: case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly:
length += int(pk.n.EncodedLength()) length += uint32(pk.n.EncodedLength())
length += int(pk.e.EncodedLength()) length += uint32(pk.e.EncodedLength())
case PubKeyAlgoDSA: case PubKeyAlgoDSA:
length += int(pk.p.EncodedLength()) length += uint32(pk.p.EncodedLength())
length += int(pk.q.EncodedLength()) length += uint32(pk.q.EncodedLength())
length += int(pk.g.EncodedLength()) length += uint32(pk.g.EncodedLength())
length += int(pk.y.EncodedLength()) length += uint32(pk.y.EncodedLength())
case PubKeyAlgoElGamal: case PubKeyAlgoElGamal:
length += int(pk.p.EncodedLength()) length += uint32(pk.p.EncodedLength())
length += int(pk.g.EncodedLength()) length += uint32(pk.g.EncodedLength())
length += int(pk.y.EncodedLength()) length += uint32(pk.y.EncodedLength())
case PubKeyAlgoECDSA: case PubKeyAlgoECDSA:
length += int(pk.oid.EncodedLength()) length += uint32(pk.oid.EncodedLength())
length += int(pk.p.EncodedLength()) length += uint32(pk.p.EncodedLength())
case PubKeyAlgoECDH: case PubKeyAlgoECDH:
length += int(pk.oid.EncodedLength()) length += uint32(pk.oid.EncodedLength())
length += int(pk.p.EncodedLength()) length += uint32(pk.p.EncodedLength())
length += int(pk.kdf.EncodedLength()) length += uint32(pk.kdf.EncodedLength())
case PubKeyAlgoEdDSA: case PubKeyAlgoEdDSA:
length += int(pk.oid.EncodedLength()) length += uint32(pk.oid.EncodedLength())
length += int(pk.p.EncodedLength()) length += uint32(pk.p.EncodedLength())
case PubKeyAlgoX25519:
length += x25519.KeySize
case PubKeyAlgoX448:
length += x448.KeySize
case PubKeyAlgoEd25519:
length += ed25519.PublicKeySize
case PubKeyAlgoEd448:
length += ed448.PublicKeySize
default: default:
panic("unknown public key algorithm") panic("unknown public key algorithm")
} }
@@ -522,7 +699,7 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) {
return return
} }
if pk.Version == 5 { if pk.Version >= 5 {
n := pk.algorithmSpecificByteCount() n := pk.algorithmSpecificByteCount()
if _, err = w.Write([]byte{ if _, err = w.Write([]byte{
byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n), byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
@@ -580,6 +757,22 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) {
} }
_, err = w.Write(pk.p.EncodedBytes()) _, err = w.Write(pk.p.EncodedBytes())
return return
case PubKeyAlgoX25519:
publicKey := pk.PublicKey.(*x25519.PublicKey)
_, err = w.Write(publicKey.Point)
return
case PubKeyAlgoX448:
publicKey := pk.PublicKey.(*x448.PublicKey)
_, err = w.Write(publicKey.Point)
return
case PubKeyAlgoEd25519:
publicKey := pk.PublicKey.(*ed25519.PublicKey)
_, err = w.Write(publicKey.Point)
return
case PubKeyAlgoEd448:
publicKey := pk.PublicKey.(*ed448.PublicKey)
_, err = w.Write(publicKey.Point)
return
} }
return errors.InvalidArgumentError("bad public-key algorithm") return errors.InvalidArgumentError("bad public-key algorithm")
} }
@@ -589,6 +782,20 @@ func (pk *PublicKey) CanSign() bool {
return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElGamal && pk.PubKeyAlgo != PubKeyAlgoECDH return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElGamal && pk.PubKeyAlgo != PubKeyAlgoECDH
} }
// VerifyHashTag returns nil iff sig appears to be a plausible signature of the data
// hashed into signed, based solely on its HashTag. signed is mutated by this call.
func VerifyHashTag(signed hash.Hash, sig *Signature) (err error) {
if sig.Version == 5 && (sig.SigType == 0x00 || sig.SigType == 0x01) {
sig.AddMetadataToHashSuffix()
}
signed.Write(sig.HashSuffix)
hashBytes := signed.Sum(nil)
if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] {
return errors.SignatureError("hash tag doesn't match")
}
return nil
}
// VerifySignature returns nil iff sig is a valid signature, made by this // VerifySignature returns nil iff sig is a valid signature, made by this
// public key, of the data hashed into signed. signed is mutated by this call. // public key, of the data hashed into signed. signed is mutated by this call.
func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err error) { func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err error) {
@@ -600,7 +807,8 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
} }
signed.Write(sig.HashSuffix) signed.Write(sig.HashSuffix)
hashBytes := signed.Sum(nil) hashBytes := signed.Sum(nil)
if sig.Version == 5 && (hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1]) { // see discussion https://github.com/ProtonMail/go-crypto/issues/107
if sig.Version >= 5 && (hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1]) {
return errors.SignatureError("hash tag doesn't match") return errors.SignatureError("hash tag doesn't match")
} }
@@ -639,6 +847,18 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
return errors.SignatureError("EdDSA verification failure") return errors.SignatureError("EdDSA verification failure")
} }
return nil return nil
case PubKeyAlgoEd25519:
ed25519PublicKey := pk.PublicKey.(*ed25519.PublicKey)
if !ed25519.Verify(ed25519PublicKey, hashBytes, sig.EdSig) {
return errors.SignatureError("Ed25519 verification failure")
}
return nil
case PubKeyAlgoEd448:
ed448PublicKey := pk.PublicKey.(*ed448.PublicKey)
if !ed448.Verify(ed448PublicKey, hashBytes, sig.EdSig) {
return errors.SignatureError("ed448 verification failure")
}
return nil
default: default:
return errors.SignatureError("Unsupported public key algorithm used in signature") return errors.SignatureError("Unsupported public key algorithm used in signature")
} }
@@ -646,11 +866,8 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro
// keySignatureHash returns a Hash of the message that needs to be signed for // keySignatureHash returns a Hash of the message that needs to be signed for
// pk to assert a subkey relationship to signed. // pk to assert a subkey relationship to signed.
func keySignatureHash(pk, signed signingKey, hashFunc crypto.Hash) (h hash.Hash, err error) { func keySignatureHash(pk, signed signingKey, hashFunc hash.Hash) (h hash.Hash, err error) {
if !hashFunc.Available() { h = hashFunc
return nil, errors.UnsupportedError("hash function")
}
h = hashFunc.New()
// RFC 4880, section 5.2.4 // RFC 4880, section 5.2.4
err = pk.SerializeForHash(h) err = pk.SerializeForHash(h)
@@ -662,10 +879,28 @@ func keySignatureHash(pk, signed signingKey, hashFunc crypto.Hash) (h hash.Hash,
return return
} }
// VerifyKeyHashTag returns nil iff sig appears to be a plausible signature over this
// primary key and subkey, based solely on its HashTag.
func (pk *PublicKey) VerifyKeyHashTag(signed *PublicKey, sig *Signature) error {
preparedHash, err := sig.PrepareVerify()
if err != nil {
return err
}
h, err := keySignatureHash(pk, signed, preparedHash)
if err != nil {
return err
}
return VerifyHashTag(h, sig)
}
// VerifyKeySignature returns nil iff sig is a valid signature, made by this // VerifyKeySignature returns nil iff sig is a valid signature, made by this
// public key, of signed. // public key, of signed.
func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error { func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error {
h, err := keySignatureHash(pk, signed, sig.Hash) preparedHash, err := sig.PrepareVerify()
if err != nil {
return err
}
h, err := keySignatureHash(pk, signed, preparedHash)
if err != nil { if err != nil {
return err return err
} }
@@ -679,10 +914,14 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error
if sig.EmbeddedSignature == nil { if sig.EmbeddedSignature == nil {
return errors.StructuralError("signing subkey is missing cross-signature") return errors.StructuralError("signing subkey is missing cross-signature")
} }
preparedHashEmbedded, err := sig.EmbeddedSignature.PrepareVerify()
if err != nil {
return err
}
// Verify the cross-signature. This is calculated over the same // Verify the cross-signature. This is calculated over the same
// data as the main signature, so we cannot just recursively // data as the main signature, so we cannot just recursively
// call signed.VerifyKeySignature(...) // call signed.VerifyKeySignature(...)
if h, err = keySignatureHash(pk, signed, sig.EmbeddedSignature.Hash); err != nil { if h, err = keySignatureHash(pk, signed, preparedHashEmbedded); err != nil {
return errors.StructuralError("error while hashing for cross-signature: " + err.Error()) return errors.StructuralError("error while hashing for cross-signature: " + err.Error())
} }
if err := signed.VerifySignature(h, sig.EmbeddedSignature); err != nil { if err := signed.VerifySignature(h, sig.EmbeddedSignature); err != nil {
@@ -693,32 +932,44 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error
return nil return nil
} }
func keyRevocationHash(pk signingKey, hashFunc crypto.Hash) (h hash.Hash, err error) { func keyRevocationHash(pk signingKey, hashFunc hash.Hash) (err error) {
if !hashFunc.Available() { return pk.SerializeForHash(hashFunc)
return nil, errors.UnsupportedError("hash function") }
// VerifyRevocationHashTag returns nil iff sig appears to be a plausible signature
// over this public key, based solely on its HashTag.
func (pk *PublicKey) VerifyRevocationHashTag(sig *Signature) (err error) {
preparedHash, err := sig.PrepareVerify()
if err != nil {
return err
} }
h = hashFunc.New() if err = keyRevocationHash(pk, preparedHash); err != nil {
return err
// RFC 4880, section 5.2.4 }
err = pk.SerializeForHash(h) return VerifyHashTag(preparedHash, sig)
return
} }
// VerifyRevocationSignature returns nil iff sig is a valid signature, made by this // VerifyRevocationSignature returns nil iff sig is a valid signature, made by this
// public key. // public key.
func (pk *PublicKey) VerifyRevocationSignature(sig *Signature) (err error) { func (pk *PublicKey) VerifyRevocationSignature(sig *Signature) (err error) {
h, err := keyRevocationHash(pk, sig.Hash) preparedHash, err := sig.PrepareVerify()
if err != nil { if err != nil {
return err return err
} }
return pk.VerifySignature(h, sig) if err = keyRevocationHash(pk, preparedHash); err != nil {
return err
}
return pk.VerifySignature(preparedHash, sig)
} }
// VerifySubkeyRevocationSignature returns nil iff sig is a valid subkey revocation signature, // VerifySubkeyRevocationSignature returns nil iff sig is a valid subkey revocation signature,
// made by this public key, of signed. // made by this public key, of signed.
func (pk *PublicKey) VerifySubkeyRevocationSignature(sig *Signature, signed *PublicKey) (err error) { func (pk *PublicKey) VerifySubkeyRevocationSignature(sig *Signature, signed *PublicKey) (err error) {
h, err := keySignatureHash(pk, signed, sig.Hash) preparedHash, err := sig.PrepareVerify()
if err != nil {
return err
}
h, err := keySignatureHash(pk, signed, preparedHash)
if err != nil { if err != nil {
return err return err
} }
@@ -727,15 +978,15 @@ func (pk *PublicKey) VerifySubkeyRevocationSignature(sig *Signature, signed *Pub
// userIdSignatureHash returns a Hash of the message that needs to be signed // userIdSignatureHash returns a Hash of the message that needs to be signed
// to assert that pk is a valid key for id. // to assert that pk is a valid key for id.
func userIdSignatureHash(id string, pk *PublicKey, hashFunc crypto.Hash) (h hash.Hash, err error) { func userIdSignatureHash(id string, pk *PublicKey, h hash.Hash) (err error) {
if !hashFunc.Available() {
return nil, errors.UnsupportedError("hash function")
}
h = hashFunc.New()
// RFC 4880, section 5.2.4 // RFC 4880, section 5.2.4
pk.SerializeSignaturePrefix(h) if err := pk.SerializeSignaturePrefix(h); err != nil {
pk.serializeWithoutHeaders(h) return err
}
if err := pk.serializeWithoutHeaders(h); err != nil {
return err
}
var buf [5]byte var buf [5]byte
buf[0] = 0xb4 buf[0] = 0xb4
@@ -746,28 +997,68 @@ func userIdSignatureHash(id string, pk *PublicKey, hashFunc crypto.Hash) (h hash
h.Write(buf[:]) h.Write(buf[:])
h.Write([]byte(id)) h.Write([]byte(id))
return return nil
}
// directKeySignatureHash returns a Hash of the message that needs to be signed.
func directKeySignatureHash(pk *PublicKey, h hash.Hash) (err error) {
return pk.SerializeForHash(h)
}
// VerifyUserIdHashTag returns nil iff sig appears to be a plausible signature over this
// public key and UserId, based solely on its HashTag
func (pk *PublicKey) VerifyUserIdHashTag(id string, sig *Signature) (err error) {
preparedHash, err := sig.PrepareVerify()
if err != nil {
return err
}
err = userIdSignatureHash(id, pk, preparedHash)
if err != nil {
return err
}
return VerifyHashTag(preparedHash, sig)
} }
// VerifyUserIdSignature returns nil iff sig is a valid signature, made by this // VerifyUserIdSignature returns nil iff sig is a valid signature, made by this
// public key, that id is the identity of pub. // public key, that id is the identity of pub.
func (pk *PublicKey) VerifyUserIdSignature(id string, pub *PublicKey, sig *Signature) (err error) { func (pk *PublicKey) VerifyUserIdSignature(id string, pub *PublicKey, sig *Signature) (err error) {
h, err := userIdSignatureHash(id, pub, sig.Hash) h, err := sig.PrepareVerify()
if err != nil { if err != nil {
return err return err
} }
if err := userIdSignatureHash(id, pub, h); err != nil {
return err
}
return pk.VerifySignature(h, sig)
}
// VerifyDirectKeySignature returns nil iff sig is a valid signature, made by this
// public key.
func (pk *PublicKey) VerifyDirectKeySignature(sig *Signature) (err error) {
h, err := sig.PrepareVerify()
if err != nil {
return err
}
if err := directKeySignatureHash(pk, h); err != nil {
return err
}
return pk.VerifySignature(h, sig) return pk.VerifySignature(h, sig)
} }
// KeyIdString returns the public key's fingerprint in capital hex // KeyIdString returns the public key's fingerprint in capital hex
// (e.g. "6C7EE1B8621CC013"). // (e.g. "6C7EE1B8621CC013").
func (pk *PublicKey) KeyIdString() string { func (pk *PublicKey) KeyIdString() string {
return fmt.Sprintf("%X", pk.Fingerprint[12:20]) return fmt.Sprintf("%016X", pk.KeyId)
} }
// KeyIdShortString returns the short form of public key's fingerprint // KeyIdShortString returns the short form of public key's fingerprint
// in capital hex, as shown by gpg --list-keys (e.g. "621CC013"). // in capital hex, as shown by gpg --list-keys (e.g. "621CC013").
// This function will return the full key id for v5 and v6 keys
// since the short key id is undefined for them.
func (pk *PublicKey) KeyIdShortString() string { func (pk *PublicKey) KeyIdShortString() string {
if pk.Version >= 5 {
return pk.KeyIdString()
}
return fmt.Sprintf("%X", pk.Fingerprint[16:20]) return fmt.Sprintf("%X", pk.Fingerprint[16:20])
} }
@@ -786,21 +1077,49 @@ func (pk *PublicKey) BitLength() (bitLength uint16, err error) {
bitLength = pk.p.BitLength() bitLength = pk.p.BitLength()
case PubKeyAlgoEdDSA: case PubKeyAlgoEdDSA:
bitLength = pk.p.BitLength() bitLength = pk.p.BitLength()
case PubKeyAlgoX25519:
bitLength = x25519.KeySize * 8
case PubKeyAlgoX448:
bitLength = x448.KeySize * 8
case PubKeyAlgoEd25519:
bitLength = ed25519.PublicKeySize * 8
case PubKeyAlgoEd448:
bitLength = ed448.PublicKeySize * 8
default: default:
err = errors.InvalidArgumentError("bad public-key algorithm") err = errors.InvalidArgumentError("bad public-key algorithm")
} }
return return
} }
// Curve returns the used elliptic curve of this public key.
// Returns an error if no elliptic curve is used.
func (pk *PublicKey) Curve() (curve Curve, err error) {
switch pk.PubKeyAlgo {
case PubKeyAlgoECDSA, PubKeyAlgoECDH, PubKeyAlgoEdDSA:
curveInfo := ecc.FindByOid(pk.oid)
if curveInfo == nil {
return "", errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid))
}
curve = Curve(curveInfo.GenName)
case PubKeyAlgoEd25519, PubKeyAlgoX25519:
curve = Curve25519
case PubKeyAlgoEd448, PubKeyAlgoX448:
curve = Curve448
default:
err = errors.InvalidArgumentError("public key does not operate with an elliptic curve")
}
return
}
// KeyExpired returns whether sig is a self-signature of a key that has // KeyExpired returns whether sig is a self-signature of a key that has
// expired or is created in the future. // expired or is created in the future.
func (pk *PublicKey) KeyExpired(sig *Signature, currentTime time.Time) bool { func (pk *PublicKey) KeyExpired(sig *Signature, currentTime time.Time) bool {
if pk.CreationTime.After(currentTime) { if pk.CreationTime.Unix() > currentTime.Unix() {
return true return true
} }
if sig.KeyLifetimeSecs == nil || *sig.KeyLifetimeSecs == 0 { if sig.KeyLifetimeSecs == nil || *sig.KeyLifetimeSecs == 0 {
return false return false
} }
expiry := pk.CreationTime.Add(time.Duration(*sig.KeyLifetimeSecs) * time.Second) expiry := pk.CreationTime.Add(time.Duration(*sig.KeyLifetimeSecs) * time.Second)
return currentTime.After(expiry) return currentTime.Unix() > expiry.Unix()
} }

View File

@@ -10,6 +10,12 @@ import (
"github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/errors"
) )
type PacketReader interface {
Next() (p Packet, err error)
Push(reader io.Reader) (err error)
Unread(p Packet)
}
// Reader reads packets from an io.Reader and allows packets to be 'unread' so // Reader reads packets from an io.Reader and allows packets to be 'unread' so
// that they result from the next call to Next. // that they result from the next call to Next.
type Reader struct { type Reader struct {
@@ -26,37 +32,81 @@ type Reader struct {
const maxReaders = 32 const maxReaders = 32
// Next returns the most recently unread Packet, or reads another packet from // Next returns the most recently unread Packet, or reads another packet from
// the top-most io.Reader. Unknown packet types are skipped. // the top-most io.Reader. Unknown/unsupported/Marker packet types are skipped.
func (r *Reader) Next() (p Packet, err error) { func (r *Reader) Next() (p Packet, err error) {
for {
p, err := r.read()
if err == io.EOF {
break
} else if err != nil {
if _, ok := err.(errors.UnknownPacketTypeError); ok {
continue
}
if _, ok := err.(errors.UnsupportedError); ok {
switch p.(type) {
case *SymmetricallyEncrypted, *AEADEncrypted, *Compressed, *LiteralData:
return nil, err
}
continue
}
return nil, err
} else {
//A marker packet MUST be ignored when received
switch p.(type) {
case *Marker:
continue
}
return p, nil
}
}
return nil, io.EOF
}
// Next returns the most recently unread Packet, or reads another packet from
// the top-most io.Reader. Unknown/Marker packet types are skipped while unsupported
// packets are returned as UnsupportedPacket type.
func (r *Reader) NextWithUnsupported() (p Packet, err error) {
for {
p, err = r.read()
if err == io.EOF {
break
} else if err != nil {
if _, ok := err.(errors.UnknownPacketTypeError); ok {
continue
}
if casteErr, ok := err.(errors.UnsupportedError); ok {
return &UnsupportedPacket{
IncompletePacket: p,
Error: casteErr,
}, nil
}
return
} else {
//A marker packet MUST be ignored when received
switch p.(type) {
case *Marker:
continue
}
return
}
}
return nil, io.EOF
}
func (r *Reader) read() (p Packet, err error) {
if len(r.q) > 0 { if len(r.q) > 0 {
p = r.q[len(r.q)-1] p = r.q[len(r.q)-1]
r.q = r.q[:len(r.q)-1] r.q = r.q[:len(r.q)-1]
return return
} }
for len(r.readers) > 0 { for len(r.readers) > 0 {
p, err = Read(r.readers[len(r.readers)-1]) p, err = Read(r.readers[len(r.readers)-1])
if err == nil {
return
}
if err == io.EOF { if err == io.EOF {
r.readers = r.readers[:len(r.readers)-1] r.readers = r.readers[:len(r.readers)-1]
continue continue
} }
// TODO: Add strict mode that rejects unknown packets, instead of ignoring them. return p, err
if _, ok := err.(errors.UnknownPacketTypeError); ok {
continue
}
if _, ok := err.(errors.UnsupportedError); ok {
switch p.(type) {
case *SymmetricallyEncrypted, *AEADEncrypted, *Compressed, *LiteralData:
return nil, err
}
continue
}
return nil, err
} }
return nil, io.EOF return nil, io.EOF
} }
@@ -84,3 +134,76 @@ func NewReader(r io.Reader) *Reader {
readers: []io.Reader{r}, readers: []io.Reader{r},
} }
} }
// CheckReader is similar to Reader but additionally
// uses the pushdown automata to verify the read packet sequence.
type CheckReader struct {
Reader
verifier *SequenceVerifier
fullyRead bool
}
// Next returns the most recently unread Packet, or reads another packet from
// the top-most io.Reader. Unknown packet types are skipped.
// If the read packet sequence does not conform to the packet composition
// rules in rfc4880, it returns an error.
func (r *CheckReader) Next() (p Packet, err error) {
if r.fullyRead {
return nil, io.EOF
}
if len(r.q) > 0 {
p = r.q[len(r.q)-1]
r.q = r.q[:len(r.q)-1]
return
}
var errMsg error
for len(r.readers) > 0 {
p, errMsg, err = ReadWithCheck(r.readers[len(r.readers)-1], r.verifier)
if errMsg != nil {
err = errMsg
return
}
if err == nil {
return
}
if err == io.EOF {
r.readers = r.readers[:len(r.readers)-1]
continue
}
//A marker packet MUST be ignored when received
switch p.(type) {
case *Marker:
continue
}
if _, ok := err.(errors.UnknownPacketTypeError); ok {
continue
}
if _, ok := err.(errors.UnsupportedError); ok {
switch p.(type) {
case *SymmetricallyEncrypted, *AEADEncrypted, *Compressed, *LiteralData:
return nil, err
}
continue
}
return nil, err
}
if errMsg = r.verifier.Next(EOSSymbol); errMsg != nil {
return nil, errMsg
}
if errMsg = r.verifier.AssertValid(); errMsg != nil {
return nil, errMsg
}
r.fullyRead = true
return nil, io.EOF
}
func NewCheckReader(r io.Reader) *CheckReader {
return &CheckReader{
Reader: Reader{
q: nil,
readers: []io.Reader{r},
},
verifier: NewSequenceVerifier(),
fullyRead: false,
}
}

View File

@@ -0,0 +1,15 @@
package packet
// Recipient type represents a Intended Recipient Fingerprint subpacket
// See https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh#name-intended-recipient-fingerpr
type Recipient struct {
KeyVersion int
Fingerprint []byte
}
func (r *Recipient) Serialize() []byte {
packet := make([]byte, len(r.Fingerprint)+1)
packet[0] = byte(r.KeyVersion)
copy(packet[1:], r.Fingerprint)
return packet
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,11 +7,13 @@ package packet
import ( import (
"bytes" "bytes"
"crypto/cipher" "crypto/cipher"
"crypto/sha256"
"io" "io"
"strconv" "strconv"
"github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/errors"
"github.com/ProtonMail/go-crypto/openpgp/s2k" "github.com/ProtonMail/go-crypto/openpgp/s2k"
"golang.org/x/crypto/hkdf"
) )
// This is the largest session key that we'll support. Since at most 256-bit cipher // This is the largest session key that we'll support. Since at most 256-bit cipher
@@ -39,10 +41,21 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error {
return err return err
} }
ske.Version = int(buf[0]) ske.Version = int(buf[0])
if ske.Version != 4 && ske.Version != 5 { if ske.Version != 4 && ske.Version != 5 && ske.Version != 6 {
return errors.UnsupportedError("unknown SymmetricKeyEncrypted version") return errors.UnsupportedError("unknown SymmetricKeyEncrypted version")
} }
if V5Disabled && ske.Version == 5 {
return errors.UnsupportedError("support for parsing v5 entities is disabled; build with `-tags v5` if needed")
}
if ske.Version > 5 {
// Scalar octet count
if _, err := readFull(r, buf[:]); err != nil {
return err
}
}
// Cipher function // Cipher function
if _, err := readFull(r, buf[:]); err != nil { if _, err := readFull(r, buf[:]); err != nil {
return err return err
@@ -52,7 +65,7 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error {
return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[0]))) return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[0])))
} }
if ske.Version == 5 { if ske.Version >= 5 {
// AEAD mode // AEAD mode
if _, err := readFull(r, buf[:]); err != nil { if _, err := readFull(r, buf[:]); err != nil {
return errors.StructuralError("cannot read AEAD octet from packet") return errors.StructuralError("cannot read AEAD octet from packet")
@@ -60,6 +73,13 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error {
ske.Mode = AEADMode(buf[0]) ske.Mode = AEADMode(buf[0])
} }
if ske.Version > 5 {
// Scalar octet count
if _, err := readFull(r, buf[:]); err != nil {
return err
}
}
var err error var err error
if ske.s2k, err = s2k.Parse(r); err != nil { if ske.s2k, err = s2k.Parse(r); err != nil {
if _, ok := err.(errors.ErrDummyPrivateKey); ok { if _, ok := err.(errors.ErrDummyPrivateKey); ok {
@@ -68,7 +88,7 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error {
return err return err
} }
if ske.Version == 5 { if ske.Version >= 5 {
// AEAD IV // AEAD IV
iv := make([]byte, ske.Mode.IvLength()) iv := make([]byte, ske.Mode.IvLength())
_, err := readFull(r, iv) _, err := readFull(r, iv)
@@ -109,8 +129,8 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) ([]byte, CipherFunc
case 4: case 4:
plaintextKey, cipherFunc, err := ske.decryptV4(key) plaintextKey, cipherFunc, err := ske.decryptV4(key)
return plaintextKey, cipherFunc, err return plaintextKey, cipherFunc, err
case 5: case 5, 6:
plaintextKey, err := ske.decryptV5(key) plaintextKey, err := ske.aeadDecrypt(ske.Version, key)
return plaintextKey, CipherFunction(0), err return plaintextKey, CipherFunction(0), err
} }
err := errors.UnsupportedError("unknown SymmetricKeyEncrypted version") err := errors.UnsupportedError("unknown SymmetricKeyEncrypted version")
@@ -136,9 +156,9 @@ func (ske *SymmetricKeyEncrypted) decryptV4(key []byte) ([]byte, CipherFunction,
return plaintextKey, cipherFunc, nil return plaintextKey, cipherFunc, nil
} }
func (ske *SymmetricKeyEncrypted) decryptV5(key []byte) ([]byte, error) { func (ske *SymmetricKeyEncrypted) aeadDecrypt(version int, key []byte) ([]byte, error) {
adata := []byte{0xc3, byte(5), byte(ske.CipherFunc), byte(ske.Mode)} adata := []byte{0xc3, byte(version), byte(ske.CipherFunc), byte(ske.Mode)}
aead := getEncryptedKeyAeadInstance(ske.CipherFunc, ske.Mode, key, adata) aead := getEncryptedKeyAeadInstance(ske.CipherFunc, ske.Mode, key, adata, version)
plaintextKey, err := aead.Open(nil, ske.iv, ske.encryptedKey, adata) plaintextKey, err := aead.Open(nil, ske.iv, ske.encryptedKey, adata)
if err != nil { if err != nil {
@@ -175,10 +195,22 @@ func SerializeSymmetricKeyEncrypted(w io.Writer, passphrase []byte, config *Conf
// the given passphrase. The returned session key must be passed to // the given passphrase. The returned session key must be passed to
// SerializeSymmetricallyEncrypted. // SerializeSymmetricallyEncrypted.
// If config is nil, sensible defaults will be used. // If config is nil, sensible defaults will be used.
// Deprecated: Use SerializeSymmetricKeyEncryptedAEADReuseKey instead.
func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, passphrase []byte, config *Config) (err error) { func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, passphrase []byte, config *Config) (err error) {
return SerializeSymmetricKeyEncryptedAEADReuseKey(w, sessionKey, passphrase, config.AEAD() != nil, config)
}
// SerializeSymmetricKeyEncryptedAEADReuseKey serializes a symmetric key packet to w.
// The packet contains the given session key, encrypted by a key derived from
// the given passphrase. The returned session key must be passed to
// SerializeSymmetricallyEncrypted.
// If aeadSupported is set, SKESK v6 is used, otherwise v4.
// Note: aeadSupported MUST match the value passed to SerializeSymmetricallyEncrypted.
// If config is nil, sensible defaults will be used.
func SerializeSymmetricKeyEncryptedAEADReuseKey(w io.Writer, sessionKey []byte, passphrase []byte, aeadSupported bool, config *Config) (err error) {
var version int var version int
if config.AEAD() != nil { if aeadSupported {
version = 5 version = 6
} else { } else {
version = 4 version = 4
} }
@@ -203,11 +235,15 @@ func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, pass
switch version { switch version {
case 4: case 4:
packetLength = 2 /* header */ + len(s2kBytes) + 1 /* cipher type */ + keySize packetLength = 2 /* header */ + len(s2kBytes) + 1 /* cipher type */ + keySize
case 5: case 5, 6:
ivLen := config.AEAD().Mode().IvLength() ivLen := config.AEAD().Mode().IvLength()
tagLen := config.AEAD().Mode().TagLength() tagLen := config.AEAD().Mode().TagLength()
packetLength = 3 + len(s2kBytes) + ivLen + keySize + tagLen packetLength = 3 + len(s2kBytes) + ivLen + keySize + tagLen
} }
if version > 5 {
packetLength += 2 // additional octet count fields
}
err = serializeHeader(w, packetTypeSymmetricKeyEncrypted, packetLength) err = serializeHeader(w, packetTypeSymmetricKeyEncrypted, packetLength)
if err != nil { if err != nil {
return return
@@ -216,13 +252,22 @@ func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, pass
// Symmetric Key Encrypted Version // Symmetric Key Encrypted Version
buf := []byte{byte(version)} buf := []byte{byte(version)}
if version > 5 {
// Scalar octet count
buf = append(buf, byte(3+len(s2kBytes)+config.AEAD().Mode().IvLength()))
}
// Cipher function // Cipher function
buf = append(buf, byte(cipherFunc)) buf = append(buf, byte(cipherFunc))
if version == 5 { if version >= 5 {
// AEAD mode // AEAD mode
buf = append(buf, byte(config.AEAD().Mode())) buf = append(buf, byte(config.AEAD().Mode()))
} }
if version > 5 {
// Scalar octet count
buf = append(buf, byte(len(s2kBytes)))
}
_, err = w.Write(buf) _, err = w.Write(buf)
if err != nil { if err != nil {
return return
@@ -243,10 +288,10 @@ func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, pass
if err != nil { if err != nil {
return return
} }
case 5: case 5, 6:
mode := config.AEAD().Mode() mode := config.AEAD().Mode()
adata := []byte{0xc3, byte(5), byte(cipherFunc), byte(mode)} adata := []byte{0xc3, byte(version), byte(cipherFunc), byte(mode)}
aead := getEncryptedKeyAeadInstance(cipherFunc, mode, keyEncryptingKey, adata) aead := getEncryptedKeyAeadInstance(cipherFunc, mode, keyEncryptingKey, adata, version)
// Sample iv using random reader // Sample iv using random reader
iv := make([]byte, config.AEAD().Mode().IvLength()) iv := make([]byte, config.AEAD().Mode().IvLength())
@@ -270,7 +315,17 @@ func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, pass
return return
} }
func getEncryptedKeyAeadInstance(c CipherFunction, mode AEADMode, inputKey, associatedData []byte) (aead cipher.AEAD) { func getEncryptedKeyAeadInstance(c CipherFunction, mode AEADMode, inputKey, associatedData []byte, version int) (aead cipher.AEAD) {
blockCipher := c.new(inputKey) var blockCipher cipher.Block
if version > 5 {
hkdfReader := hkdf.New(sha256.New, inputKey, []byte{}, associatedData)
encryptionKey := make([]byte, c.KeySize())
_, _ = readFull(hkdfReader, encryptionKey)
blockCipher = c.new(encryptionKey)
} else {
blockCipher = c.new(inputKey)
}
return mode.new(blockCipher) return mode.new(blockCipher)
} }

View File

@@ -74,6 +74,10 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read
// SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet // SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet
// to w and returns a WriteCloser to which the to-be-encrypted packets can be // to w and returns a WriteCloser to which the to-be-encrypted packets can be
// written. // written.
// If aeadSupported is set to true, SEIPDv2 is used with the indicated CipherSuite.
// Otherwise, SEIPDv1 is used with the indicated CipherFunction.
// Note: aeadSupported MUST match the value passed to SerializeEncryptedKeyAEAD
// and/or SerializeSymmetricKeyEncryptedAEADReuseKey.
// If config is nil, sensible defaults will be used. // If config is nil, sensible defaults will be used.
func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, aeadSupported bool, cipherSuite CipherSuite, key []byte, config *Config) (Contents io.WriteCloser, err error) { func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, aeadSupported bool, cipherSuite CipherSuite, key []byte, config *Config) (Contents io.WriteCloser, err error) {
writeCloser := noOpCloser{w} writeCloser := noOpCloser{w}

View File

@@ -7,7 +7,9 @@ package packet
import ( import (
"crypto/cipher" "crypto/cipher"
"crypto/sha256" "crypto/sha256"
"fmt"
"io" "io"
"strconv"
"github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/errors"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
@@ -25,19 +27,19 @@ func (se *SymmetricallyEncrypted) parseAead(r io.Reader) error {
se.Cipher = CipherFunction(headerData[0]) se.Cipher = CipherFunction(headerData[0])
// cipherFunc must have block size 16 to use AEAD // cipherFunc must have block size 16 to use AEAD
if se.Cipher.blockSize() != 16 { if se.Cipher.blockSize() != 16 {
return errors.UnsupportedError("invalid aead cipher: " + string(se.Cipher)) return errors.UnsupportedError("invalid aead cipher: " + strconv.Itoa(int(se.Cipher)))
} }
// Mode // Mode
se.Mode = AEADMode(headerData[1]) se.Mode = AEADMode(headerData[1])
if se.Mode.TagLength() == 0 { if se.Mode.TagLength() == 0 {
return errors.UnsupportedError("unknown aead mode: " + string(se.Mode)) return errors.UnsupportedError("unknown aead mode: " + strconv.Itoa(int(se.Mode)))
} }
// Chunk size // Chunk size
se.ChunkSizeByte = headerData[2] se.ChunkSizeByte = headerData[2]
if se.ChunkSizeByte > 16 { if se.ChunkSizeByte > 16 {
return errors.UnsupportedError("invalid aead chunk size byte: " + string(se.ChunkSizeByte)) return errors.UnsupportedError("invalid aead chunk size byte: " + strconv.Itoa(int(se.ChunkSizeByte)))
} }
// Salt // Salt
@@ -62,11 +64,16 @@ func (se *SymmetricallyEncrypted) associatedData() []byte {
// decryptAead decrypts a V2 SEIPD packet (AEAD) as specified in // decryptAead decrypts a V2 SEIPD packet (AEAD) as specified in
// https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-07.html#section-5.13.2 // https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-07.html#section-5.13.2
func (se *SymmetricallyEncrypted) decryptAead(inputKey []byte) (io.ReadCloser, error) { func (se *SymmetricallyEncrypted) decryptAead(inputKey []byte) (io.ReadCloser, error) {
aead, nonce := getSymmetricallyEncryptedAeadInstance(se.Cipher, se.Mode, inputKey, se.Salt[:], se.associatedData()) if se.Cipher.KeySize() != len(inputKey) {
return nil, errors.StructuralError(fmt.Sprintf("invalid session key length for cipher: got %d bytes, but expected %d bytes", len(inputKey), se.Cipher.KeySize()))
}
aead, nonce := getSymmetricallyEncryptedAeadInstance(se.Cipher, se.Mode, inputKey, se.Salt[:], se.associatedData())
// Carry the first tagLen bytes // Carry the first tagLen bytes
chunkSize := decodeAEADChunkSize(se.ChunkSizeByte)
tagLen := se.Mode.TagLength() tagLen := se.Mode.TagLength()
peekedBytes := make([]byte, tagLen) chunkBytes := make([]byte, chunkSize+tagLen*2)
peekedBytes := chunkBytes[chunkSize+tagLen:]
n, err := io.ReadFull(se.Contents, peekedBytes) n, err := io.ReadFull(se.Contents, peekedBytes)
if n < tagLen || (err != nil && err != io.EOF) { if n < tagLen || (err != nil && err != io.EOF) {
return nil, errors.StructuralError("not enough data to decrypt:" + err.Error()) return nil, errors.StructuralError("not enough data to decrypt:" + err.Error())
@@ -76,12 +83,13 @@ func (se *SymmetricallyEncrypted) decryptAead(inputKey []byte) (io.ReadCloser, e
aeadCrypter: aeadCrypter{ aeadCrypter: aeadCrypter{
aead: aead, aead: aead,
chunkSize: decodeAEADChunkSize(se.ChunkSizeByte), chunkSize: decodeAEADChunkSize(se.ChunkSizeByte),
initialNonce: nonce, nonce: nonce,
associatedData: se.associatedData(), associatedData: se.associatedData(),
chunkIndex: make([]byte, 8), chunkIndex: nonce[len(nonce)-8:],
packetTag: packetTypeSymmetricallyEncryptedIntegrityProtected, packetTag: packetTypeSymmetricallyEncryptedIntegrityProtected,
}, },
reader: se.Contents, reader: se.Contents,
chunkBytes: chunkBytes,
peekedBytes: peekedBytes, peekedBytes: peekedBytes,
}, nil }, nil
} }
@@ -115,7 +123,7 @@ func serializeSymmetricallyEncryptedAead(ciphertext io.WriteCloser, cipherSuite
// Random salt // Random salt
salt := make([]byte, aeadSaltSize) salt := make([]byte, aeadSaltSize)
if _, err := rand.Read(salt); err != nil { if _, err := io.ReadFull(rand, salt); err != nil {
return nil, err return nil, err
} }
@@ -125,16 +133,20 @@ func serializeSymmetricallyEncryptedAead(ciphertext io.WriteCloser, cipherSuite
aead, nonce := getSymmetricallyEncryptedAeadInstance(cipherSuite.Cipher, cipherSuite.Mode, inputKey, salt, prefix) aead, nonce := getSymmetricallyEncryptedAeadInstance(cipherSuite.Cipher, cipherSuite.Mode, inputKey, salt, prefix)
chunkSize := decodeAEADChunkSize(chunkSizeByte)
tagLen := aead.Overhead()
chunkBytes := make([]byte, chunkSize+tagLen)
return &aeadEncrypter{ return &aeadEncrypter{
aeadCrypter: aeadCrypter{ aeadCrypter: aeadCrypter{
aead: aead, aead: aead,
chunkSize: decodeAEADChunkSize(chunkSizeByte), chunkSize: chunkSize,
associatedData: prefix, associatedData: prefix,
chunkIndex: make([]byte, 8), nonce: nonce,
initialNonce: nonce, chunkIndex: nonce[len(nonce)-8:],
packetTag: packetTypeSymmetricallyEncryptedIntegrityProtected, packetTag: packetTypeSymmetricallyEncryptedIntegrityProtected,
}, },
writer: ciphertext, writer: ciphertext,
chunkBytes: chunkBytes,
}, nil }, nil
} }
@@ -144,10 +156,10 @@ func getSymmetricallyEncryptedAeadInstance(c CipherFunction, mode AEADMode, inpu
encryptionKey := make([]byte, c.KeySize()) encryptionKey := make([]byte, c.KeySize())
_, _ = readFull(hkdfReader, encryptionKey) _, _ = readFull(hkdfReader, encryptionKey)
// Last 64 bits of nonce are the counter nonce = make([]byte, mode.IvLength())
nonce = make([]byte, mode.IvLength()-8)
_, _ = readFull(hkdfReader, nonce) // Last 64 bits of nonce are the counter
_, _ = readFull(hkdfReader, nonce[:len(nonce)-8])
blockCipher := c.new(encryptionKey) blockCipher := c.new(encryptionKey)
aead = mode.new(blockCipher) aead = mode.new(blockCipher)

View File

@@ -148,7 +148,7 @@ const mdcPacketTagByte = byte(0x80) | 0x40 | 19
func (ser *seMDCReader) Close() error { func (ser *seMDCReader) Close() error {
if ser.error { if ser.error {
return errors.ErrMDCMissing return errors.ErrMDCHashMismatch
} }
for !ser.eof { for !ser.eof {
@@ -159,7 +159,7 @@ func (ser *seMDCReader) Close() error {
break break
} }
if err != nil { if err != nil {
return errors.ErrMDCMissing return errors.ErrMDCHashMismatch
} }
} }
@@ -172,7 +172,7 @@ func (ser *seMDCReader) Close() error {
// The hash already includes the MDC header, but we still check its value // The hash already includes the MDC header, but we still check its value
// to confirm encryption correctness // to confirm encryption correctness
if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size { if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size {
return errors.ErrMDCMissing return errors.ErrMDCHashMismatch
} }
return nil return nil
} }
@@ -237,9 +237,9 @@ func serializeSymmetricallyEncryptedMdc(ciphertext io.WriteCloser, c CipherFunct
block := c.new(key) block := c.new(key)
blockSize := block.BlockSize() blockSize := block.BlockSize()
iv := make([]byte, blockSize) iv := make([]byte, blockSize)
_, err = config.Random().Read(iv) _, err = io.ReadFull(config.Random(), iv)
if err != nil { if err != nil {
return return nil, err
} }
s, prefix := NewOCFBEncrypter(block, iv, OCFBNoResync) s, prefix := NewOCFBEncrypter(block, iv, OCFBNoResync)
_, err = ciphertext.Write(prefix) _, err = ciphertext.Write(prefix)

View File

@@ -9,7 +9,6 @@ import (
"image" "image"
"image/jpeg" "image/jpeg"
"io" "io"
"io/ioutil"
) )
const UserAttrImageSubpacket = 1 const UserAttrImageSubpacket = 1
@@ -63,7 +62,7 @@ func NewUserAttribute(contents ...*OpaqueSubpacket) *UserAttribute {
func (uat *UserAttribute) parse(r io.Reader) (err error) { func (uat *UserAttribute) parse(r io.Reader) (err error) {
// RFC 4880, section 5.13 // RFC 4880, section 5.13
b, err := ioutil.ReadAll(r) b, err := io.ReadAll(r)
if err != nil { if err != nil {
return return
} }

View File

@@ -6,7 +6,6 @@ package packet
import ( import (
"io" "io"
"io/ioutil"
"strings" "strings"
) )
@@ -66,7 +65,7 @@ func NewUserId(name, comment, email string) *UserId {
func (uid *UserId) parse(r io.Reader) (err error) { func (uid *UserId) parse(r io.Reader) (err error) {
// RFC 4880, section 5.11 // RFC 4880, section 5.11
b, err := ioutil.ReadAll(r) b, err := io.ReadAll(r)
if err != nil { if err != nil {
return return
} }

View File

@@ -46,6 +46,7 @@ type MessageDetails struct {
DecryptedWith Key // the private key used to decrypt the message, if any. DecryptedWith Key // the private key used to decrypt the message, if any.
IsSigned bool // true if the message is signed. IsSigned bool // true if the message is signed.
SignedByKeyId uint64 // the key id of the signer, if any. SignedByKeyId uint64 // the key id of the signer, if any.
SignedByFingerprint []byte // the key fingerprint of the signer, if any.
SignedBy *Key // the key of the signer, if available. SignedBy *Key // the key of the signer, if available.
LiteralData *packet.LiteralData // the metadata of the contents LiteralData *packet.LiteralData // the metadata of the contents
UnverifiedBody io.Reader // the contents of the message. UnverifiedBody io.Reader // the contents of the message.
@@ -117,7 +118,7 @@ ParsePackets:
// This packet contains the decryption key encrypted to a public key. // This packet contains the decryption key encrypted to a public key.
md.EncryptedToKeyIds = append(md.EncryptedToKeyIds, p.KeyId) md.EncryptedToKeyIds = append(md.EncryptedToKeyIds, p.KeyId)
switch p.Algo { switch p.Algo {
case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSAEncryptOnly, packet.PubKeyAlgoElGamal, packet.PubKeyAlgoECDH: case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSAEncryptOnly, packet.PubKeyAlgoElGamal, packet.PubKeyAlgoECDH, packet.PubKeyAlgoX25519, packet.PubKeyAlgoX448:
break break
default: default:
continue continue
@@ -232,7 +233,7 @@ FindKey:
} }
mdFinal, sensitiveParsingErr := readSignedMessage(packets, md, keyring, config) mdFinal, sensitiveParsingErr := readSignedMessage(packets, md, keyring, config)
if sensitiveParsingErr != nil { if sensitiveParsingErr != nil {
return nil, errors.StructuralError("parsing error") return nil, errors.HandleSensitiveParsingError(sensitiveParsingErr, md.decrypted != nil)
} }
return mdFinal, nil return mdFinal, nil
} }
@@ -258,7 +259,7 @@ FindLiteralData:
} }
switch p := p.(type) { switch p := p.(type) {
case *packet.Compressed: case *packet.Compressed:
if err := packets.Push(p.Body); err != nil { if err := packets.Push(p.LimitedBodyReader(config.DecompressedMessageSizeLimit())); err != nil {
return nil, err return nil, err
} }
case *packet.OnePassSignature: case *packet.OnePassSignature:
@@ -270,13 +271,17 @@ FindLiteralData:
prevLast = true prevLast = true
} }
h, wrappedHash, err = hashForSignature(p.Hash, p.SigType) h, wrappedHash, err = hashForSignature(p.Hash, p.SigType, p.Salt)
if err != nil { if err != nil {
md.SignatureError = err md.SignatureError = err
} }
md.IsSigned = true md.IsSigned = true
if p.Version == 6 {
md.SignedByFingerprint = p.KeyFingerprint
}
md.SignedByKeyId = p.KeyId md.SignedByKeyId = p.KeyId
if keyring != nil { if keyring != nil {
keys := keyring.KeysByIdUsage(p.KeyId, packet.KeyFlagSign) keys := keyring.KeysByIdUsage(p.KeyId, packet.KeyFlagSign)
if len(keys) > 0 { if len(keys) > 0 {
@@ -292,7 +297,7 @@ FindLiteralData:
if md.IsSigned && md.SignatureError == nil { if md.IsSigned && md.SignatureError == nil {
md.UnverifiedBody = &signatureCheckReader{packets, h, wrappedHash, md, config} md.UnverifiedBody = &signatureCheckReader{packets, h, wrappedHash, md, config}
} else if md.decrypted != nil { } else if md.decrypted != nil {
md.UnverifiedBody = checkReader{md} md.UnverifiedBody = &checkReader{md, false}
} else { } else {
md.UnverifiedBody = md.LiteralData.Body md.UnverifiedBody = md.LiteralData.Body
} }
@@ -300,12 +305,22 @@ FindLiteralData:
return md, nil return md, nil
} }
func wrapHashForSignature(hashFunc hash.Hash, sigType packet.SignatureType) (hash.Hash, error) {
switch sigType {
case packet.SigTypeBinary:
return hashFunc, nil
case packet.SigTypeText:
return NewCanonicalTextHash(hashFunc), nil
}
return nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType)))
}
// hashForSignature returns a pair of hashes that can be used to verify a // hashForSignature returns a pair of hashes that can be used to verify a
// signature. The signature may specify that the contents of the signed message // signature. The signature may specify that the contents of the signed message
// should be preprocessed (i.e. to normalize line endings). Thus this function // should be preprocessed (i.e. to normalize line endings). Thus this function
// returns two hashes. The second should be used to hash the message itself and // returns two hashes. The second should be used to hash the message itself and
// performs any needed preprocessing. // performs any needed preprocessing.
func hashForSignature(hashFunc crypto.Hash, sigType packet.SignatureType) (hash.Hash, hash.Hash, error) { func hashForSignature(hashFunc crypto.Hash, sigType packet.SignatureType, sigSalt []byte) (hash.Hash, hash.Hash, error) {
if _, ok := algorithm.HashToHashIdWithSha1(hashFunc); !ok { if _, ok := algorithm.HashToHashIdWithSha1(hashFunc); !ok {
return nil, nil, errors.UnsupportedError("unsupported hash function") return nil, nil, errors.UnsupportedError("unsupported hash function")
} }
@@ -313,14 +328,19 @@ func hashForSignature(hashFunc crypto.Hash, sigType packet.SignatureType) (hash.
return nil, nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hashFunc))) return nil, nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hashFunc)))
} }
h := hashFunc.New() h := hashFunc.New()
if sigSalt != nil {
h.Write(sigSalt)
}
wrappedHash, err := wrapHashForSignature(h, sigType)
if err != nil {
return nil, nil, err
}
switch sigType { switch sigType {
case packet.SigTypeBinary: case packet.SigTypeBinary:
return h, h, nil return h, wrappedHash, nil
case packet.SigTypeText: case packet.SigTypeText:
return h, NewCanonicalTextHash(h), nil return h, wrappedHash, nil
} }
return nil, nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType))) return nil, nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType)))
} }
@@ -328,21 +348,27 @@ func hashForSignature(hashFunc crypto.Hash, sigType packet.SignatureType) (hash.
// it closes the ReadCloser from any SymmetricallyEncrypted packet to trigger // it closes the ReadCloser from any SymmetricallyEncrypted packet to trigger
// MDC checks. // MDC checks.
type checkReader struct { type checkReader struct {
md *MessageDetails md *MessageDetails
checked bool
} }
func (cr checkReader) Read(buf []byte) (int, error) { func (cr *checkReader) Read(buf []byte) (int, error) {
n, sensitiveParsingError := cr.md.LiteralData.Body.Read(buf) n, sensitiveParsingError := cr.md.LiteralData.Body.Read(buf)
if sensitiveParsingError == io.EOF { if sensitiveParsingError == io.EOF {
if cr.checked {
// Only check once
return n, io.EOF
}
mdcErr := cr.md.decrypted.Close() mdcErr := cr.md.decrypted.Close()
if mdcErr != nil { if mdcErr != nil {
return n, mdcErr return n, mdcErr
} }
cr.checked = true
return n, io.EOF return n, io.EOF
} }
if sensitiveParsingError != nil { if sensitiveParsingError != nil {
return n, errors.StructuralError("parsing error") return n, errors.HandleSensitiveParsingError(sensitiveParsingError, true)
} }
return n, nil return n, nil
@@ -366,6 +392,7 @@ func (scr *signatureCheckReader) Read(buf []byte) (int, error) {
scr.wrappedHash.Write(buf[:n]) scr.wrappedHash.Write(buf[:n])
} }
readsDecryptedData := scr.md.decrypted != nil
if sensitiveParsingError == io.EOF { if sensitiveParsingError == io.EOF {
var p packet.Packet var p packet.Packet
var readError error var readError error
@@ -384,7 +411,7 @@ func (scr *signatureCheckReader) Read(buf []byte) (int, error) {
key := scr.md.SignedBy key := scr.md.SignedBy
signatureError := key.PublicKey.VerifySignature(scr.h, sig) signatureError := key.PublicKey.VerifySignature(scr.h, sig)
if signatureError == nil { if signatureError == nil {
signatureError = checkSignatureDetails(key, sig, scr.config) signatureError = checkMessageSignatureDetails(key, sig, scr.config)
} }
scr.md.Signature = sig scr.md.Signature = sig
scr.md.SignatureError = signatureError scr.md.SignatureError = signatureError
@@ -408,16 +435,15 @@ func (scr *signatureCheckReader) Read(buf []byte) (int, error) {
// unsigned hash of its own. In order to check this we need to // unsigned hash of its own. In order to check this we need to
// close that Reader. // close that Reader.
if scr.md.decrypted != nil { if scr.md.decrypted != nil {
mdcErr := scr.md.decrypted.Close() if sensitiveParsingError := scr.md.decrypted.Close(); sensitiveParsingError != nil {
if mdcErr != nil { return n, errors.HandleSensitiveParsingError(sensitiveParsingError, true)
return n, mdcErr
} }
} }
return n, io.EOF return n, io.EOF
} }
if sensitiveParsingError != nil { if sensitiveParsingError != nil {
return n, errors.StructuralError("parsing error") return n, errors.HandleSensitiveParsingError(sensitiveParsingError, readsDecryptedData)
} }
return n, nil return n, nil
@@ -428,14 +454,13 @@ func (scr *signatureCheckReader) Read(buf []byte) (int, error) {
// if any, and a possible signature verification error. // if any, and a possible signature verification error.
// If the signer isn't known, ErrUnknownIssuer is returned. // If the signer isn't known, ErrUnknownIssuer is returned.
func VerifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) { func VerifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) {
var expectedHashes []crypto.Hash return verifyDetachedSignature(keyring, signed, signature, nil, false, config)
return verifyDetachedSignature(keyring, signed, signature, expectedHashes, config)
} }
// VerifyDetachedSignatureAndHash performs the same actions as // VerifyDetachedSignatureAndHash performs the same actions as
// VerifyDetachedSignature and checks that the expected hash functions were used. // VerifyDetachedSignature and checks that the expected hash functions were used.
func VerifyDetachedSignatureAndHash(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) { func VerifyDetachedSignatureAndHash(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) {
return verifyDetachedSignature(keyring, signed, signature, expectedHashes, config) return verifyDetachedSignature(keyring, signed, signature, expectedHashes, true, config)
} }
// CheckDetachedSignature takes a signed file and a detached signature and // CheckDetachedSignature takes a signed file and a detached signature and
@@ -443,25 +468,24 @@ func VerifyDetachedSignatureAndHash(keyring KeyRing, signed, signature io.Reader
// signature verification error. If the signer isn't known, // signature verification error. If the signer isn't known,
// ErrUnknownIssuer is returned. // ErrUnknownIssuer is returned.
func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader, config *packet.Config) (signer *Entity, err error) { func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader, config *packet.Config) (signer *Entity, err error) {
var expectedHashes []crypto.Hash _, signer, err = verifyDetachedSignature(keyring, signed, signature, nil, false, config)
return CheckDetachedSignatureAndHash(keyring, signed, signature, expectedHashes, config) return
} }
// CheckDetachedSignatureAndHash performs the same actions as // CheckDetachedSignatureAndHash performs the same actions as
// CheckDetachedSignature and checks that the expected hash functions were used. // CheckDetachedSignature and checks that the expected hash functions were used.
func CheckDetachedSignatureAndHash(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, config *packet.Config) (signer *Entity, err error) { func CheckDetachedSignatureAndHash(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, config *packet.Config) (signer *Entity, err error) {
_, signer, err = verifyDetachedSignature(keyring, signed, signature, expectedHashes, config) _, signer, err = verifyDetachedSignature(keyring, signed, signature, expectedHashes, true, config)
return return
} }
func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) { func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, checkHashes bool, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) {
var issuerKeyId uint64 var issuerKeyId uint64
var hashFunc crypto.Hash var hashFunc crypto.Hash
var sigType packet.SignatureType var sigType packet.SignatureType
var keys []Key var keys []Key
var p packet.Packet var p packet.Packet
expectedHashesLen := len(expectedHashes)
packets := packet.NewReader(signature) packets := packet.NewReader(signature)
for { for {
p, err = packets.Next() p, err = packets.Next()
@@ -483,16 +507,19 @@ func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expec
issuerKeyId = *sig.IssuerKeyId issuerKeyId = *sig.IssuerKeyId
hashFunc = sig.Hash hashFunc = sig.Hash
sigType = sig.SigType sigType = sig.SigType
if checkHashes {
for i, expectedHash := range expectedHashes { matchFound := false
if hashFunc == expectedHash { // check for hashes
break for _, expectedHash := range expectedHashes {
if hashFunc == expectedHash {
matchFound = true
break
}
} }
if i+1 == expectedHashesLen { if !matchFound {
return nil, nil, errors.StructuralError("hash algorithm mismatch with cleartext message headers") return nil, nil, errors.StructuralError("hash algorithm or salt mismatch with cleartext message headers")
} }
} }
keys = keyring.KeysByIdUsage(issuerKeyId, packet.KeyFlagSign) keys = keyring.KeysByIdUsage(issuerKeyId, packet.KeyFlagSign)
if len(keys) > 0 { if len(keys) > 0 {
break break
@@ -503,7 +530,11 @@ func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expec
panic("unreachable") panic("unreachable")
} }
h, wrappedHash, err := hashForSignature(hashFunc, sigType) h, err := sig.PrepareVerify()
if err != nil {
return nil, nil, err
}
wrappedHash, err := wrapHashForSignature(h, sigType)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -515,7 +546,7 @@ func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expec
for _, key := range keys { for _, key := range keys {
err = key.PublicKey.VerifySignature(h, sig) err = key.PublicKey.VerifySignature(h, sig)
if err == nil { if err == nil {
return sig, key.Entity, checkSignatureDetails(&key, sig, config) return sig, key.Entity, checkMessageSignatureDetails(&key, sig, config)
} }
} }
@@ -533,7 +564,7 @@ func CheckArmoredDetachedSignature(keyring KeyRing, signed, signature io.Reader,
return CheckDetachedSignature(keyring, signed, body, config) return CheckDetachedSignature(keyring, signed, body, config)
} }
// checkSignatureDetails returns an error if: // checkMessageSignatureDetails returns an error if:
// - The signature (or one of the binding signatures mentioned below) // - The signature (or one of the binding signatures mentioned below)
// has a unknown critical notation data subpacket // has a unknown critical notation data subpacket
// - The primary key of the signing entity is revoked // - The primary key of the signing entity is revoked
@@ -551,15 +582,11 @@ func CheckArmoredDetachedSignature(keyring KeyRing, signed, signature io.Reader,
// NOTE: The order of these checks is important, as the caller may choose to // NOTE: The order of these checks is important, as the caller may choose to
// ignore ErrSignatureExpired or ErrKeyExpired errors, but should never // ignore ErrSignatureExpired or ErrKeyExpired errors, but should never
// ignore any other errors. // ignore any other errors.
// func checkMessageSignatureDetails(key *Key, signature *packet.Signature, config *packet.Config) error {
// TODO: Also return an error if:
// - The primary key is expired according to a direct-key signature
// - (For V5 keys only:) The direct-key signature (exists and) is expired
func checkSignatureDetails(key *Key, signature *packet.Signature, config *packet.Config) error {
now := config.Now() now := config.Now()
primaryIdentity := key.Entity.PrimaryIdentity() primarySelfSignature, primaryIdentity := key.Entity.PrimarySelfSignature()
signedBySubKey := key.PublicKey != key.Entity.PrimaryKey signedBySubKey := key.PublicKey != key.Entity.PrimaryKey
sigsToCheck := []*packet.Signature{signature, primaryIdentity.SelfSignature} sigsToCheck := []*packet.Signature{signature, primarySelfSignature}
if signedBySubKey { if signedBySubKey {
sigsToCheck = append(sigsToCheck, key.SelfSignature, key.SelfSignature.EmbeddedSignature) sigsToCheck = append(sigsToCheck, key.SelfSignature, key.SelfSignature.EmbeddedSignature)
} }
@@ -572,10 +599,10 @@ func checkSignatureDetails(key *Key, signature *packet.Signature, config *packet
} }
if key.Entity.Revoked(now) || // primary key is revoked if key.Entity.Revoked(now) || // primary key is revoked
(signedBySubKey && key.Revoked(now)) || // subkey is revoked (signedBySubKey && key.Revoked(now)) || // subkey is revoked
primaryIdentity.Revoked(now) { // primary identity is revoked (primaryIdentity != nil && primaryIdentity.Revoked(now)) { // primary identity is revoked for v4
return errors.ErrKeyRevoked return errors.ErrKeyRevoked
} }
if key.Entity.PrimaryKey.KeyExpired(primaryIdentity.SelfSignature, now) { // primary key is expired if key.Entity.PrimaryKey.KeyExpired(primarySelfSignature, now) { // primary key is expired
return errors.ErrKeyExpired return errors.ErrKeyExpired
} }
if signedBySubKey { if signedBySubKey {

Some files were not shown because too many files have changed in this diff Show More