7
7
"go/ast"
8
8
"go/printer"
9
9
"go/token"
10
- "go/types "
10
+ "slices "
11
11
12
12
"golang.org/x/tools/go/analysis"
13
13
"golang.org/x/tools/go/analysis/passes/inspect"
@@ -169,13 +169,14 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
169
169
continue
170
170
}
171
171
172
+ // Ignore [context.Background] & [context.TODO].
173
+ if isContextFunction (assignStmt .Rhs [0 ], "Background" , "TODO" ) {
174
+ continue
175
+ }
176
+
172
177
// allow assignment to non-pointer children of values defined within the loop
173
- if lhs := getRootIdent (pass , assignStmt .Lhs [0 ]); lhs != nil {
174
- if obj := pass .TypesInfo .ObjectOf (lhs ); obj != nil {
175
- if checkObjectScopeWithinNode (obj .Parent (), node ) {
176
- continue // definition is within the loop
177
- }
178
- }
178
+ if isWithinLoop (assignStmt .Lhs [0 ], node , pass ) {
179
+ continue
179
180
}
180
181
181
182
return assignStmt
@@ -184,16 +185,51 @@ func findNestedContext(pass *analysis.Pass, node ast.Node, stmts []ast.Stmt) *as
184
185
return nil
185
186
}
186
187
187
- func checkObjectScopeWithinNode (scope * types.Scope , node ast.Node ) bool {
188
- if scope == nil {
188
+ // render returns the pretty-print of the given node
189
+ func render (fset * token.FileSet , x interface {}) ([]byte , error ) {
190
+ var buf bytes.Buffer
191
+ if err := printer .Fprint (& buf , fset , x ); err != nil {
192
+ return nil , fmt .Errorf ("printing node: %w" , err )
193
+ }
194
+ return buf .Bytes (), nil
195
+ }
196
+
197
+ func isContextFunction (exp ast.Expr , fnName ... string ) bool {
198
+ call , ok := exp .(* ast.CallExpr )
199
+ if ! ok {
200
+ return false
201
+ }
202
+
203
+ selector , ok := call .Fun .(* ast.SelectorExpr )
204
+ if ! ok {
205
+ return false
206
+ }
207
+
208
+ ident , ok := selector .X .(* ast.Ident )
209
+ if ! ok {
210
+ return false
211
+ }
212
+
213
+ return ident .Name == "context" && slices .Contains (fnName , selector .Sel .Name )
214
+ }
215
+
216
+ func isWithinLoop (exp ast.Expr , node ast.Node , pass * analysis.Pass ) bool {
217
+ lhs := getRootIdent (pass , exp )
218
+ if lhs == nil {
189
219
return false
190
220
}
191
221
192
- if scope .Pos () >= node .Pos () && scope .End () <= node .End () {
193
- return true
222
+ obj := pass .TypesInfo .ObjectOf (lhs )
223
+ if obj == nil {
224
+ return false
225
+ }
226
+
227
+ scope := obj .Parent ()
228
+ if scope == nil {
229
+ return false
194
230
}
195
231
196
- return false
232
+ return scope . Pos () >= node . Pos () && scope . End () <= node . End ()
197
233
}
198
234
199
235
func getRootIdent (pass * analysis.Pass , node ast.Node ) * ast.Ident {
@@ -213,12 +249,3 @@ func getRootIdent(pass *analysis.Pass, node ast.Node) *ast.Ident {
213
249
}
214
250
}
215
251
}
216
-
217
- // render returns the pretty-print of the given node
218
- func render (fset * token.FileSet , x interface {}) ([]byte , error ) {
219
- var buf bytes.Buffer
220
- if err := printer .Fprint (& buf , fset , x ); err != nil {
221
- return nil , fmt .Errorf ("printing node: %w" , err )
222
- }
223
- return buf .Bytes (), nil
224
- }
0 commit comments